小编给大家分享一下pytorch中cnn如何识别手写的字并实现自建图片数据,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!
具体如下:
# library # standard library import os # third-party library import torch import torch.nn as nn from torch.autograd import Variable from torch.utils.data import Dataset, DataLoader import torchvision import matplotlib.pyplot as plt from PIL import Image import numpy as np # torch.manual_seed(1) # reproducible # Hyper Parameters EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch BATCH_SIZE = 50 LR = 0.001 # learning rate root = "./mnist/raw/" def default_loader(path): # return Image.open(path).convert('RGB') return Image.open(path) class MyDataset(Dataset): def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0], int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader fh.close() def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) img = Image.fromarray(np.array(img), mode='L') if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor()) train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True) test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor()) test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE) class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( # input shape (1, 28, 28) nn.Conv2d( in_channels=1, # input height out_channels=16, # n_filters kernel_size=5, # filter size stride=1, # filter movement/step padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1 ), # output shape (16, 28, 28) nn.ReLU(), # activation nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14) ) self.conv2 = nn.Sequential( # input shape (16, 14, 14) nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14) nn.ReLU(), # activation nn.MaxPool2d(2), # output shape (32, 7, 7) ) self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7) output = self.out(x) return output, x # return x for visualization cnn = CNN() print(cnn) # net architecture optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # training and testing for epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader b_x = Variable(x) # batch x b_y = Variable(y) # batch y output = cnn(b_x)[0] # cnn output loss = loss_func(output, b_y) # cross entropy loss optimizer.zero_grad() # clear gradients for this training step loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients if step % 50 == 0: cnn.eval() eval_loss = 0. eval_acc = 0. for i, (tx, ty) in enumerate(test_loader): t_x = Variable(tx) t_y = Variable(ty) output = cnn(t_x)[0] loss = loss_func(output, t_y) eval_loss += loss.data[0] pred = torch.max(output, 1)[1] num_correct = (pred == t_y).sum() eval_acc += float(num_correct.data[0]) acc_rate = eval_acc / float(len(test_data)) print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》
结果如下:
以上是“pytorch中cnn如何识别手写的字并实现自建图片数据”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注创新互联行业资讯频道!
当前标题:pytorch中cnn如何识别手写的字并实现自建图片数据-创新互联
文章分享:https://www.cdcxhl.com/article10/ddpcgo.html
成都网站建设公司_创新互联,为您提供网站设计、域名注册、静态网站、App开发、网站排名、网站内链
声明:本网站发布的内容(图片、视频和文字)以用户投稿、用户转载内容为主,如果涉及侵权请尽快告知,我们将会在第一时间删除。文章观点不代表本网站立场,如需处理请联系客服。电话:028-86922220;邮箱:631063699@qq.com。内容未经允许不得转载,或转载时需注明来源: 创新互联