本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com
蒙特卡罗算法是一个强大的算法,本文讲解如何使用蒙特卡罗算法来解决手写数字识别问题
本文提供实现蒙特卡罗算法识别手写数字的实现代码,并附加一个GUI界面用于手写数字并进行识别
通过本文,可以进一步加强对蒙特卡罗算法的理解和实际应用
本节介绍手写数字识别的问题与数据说明
手写数字数据介绍
matlab2018a中自带的digitimages.mat数据就是写手数字的数据,
共包含3000个样本,每个样本是28*28的图片数据,
不妨每个数字都打印5个样本示例,如下
下面我们使用蒙特卡罗法来识别手写数字
本节讲解如何设计一个蒙特卡罗算法来识别手写数字
蒙特卡罗法用于手写数字识别的算法设计
蒙特卡罗法用于数字识别的算法设计
数字识别实际也是类别识别
所以算法设计与《螃蟹识别》中的流程是一致的,
具体算法设计如下:
将历史样本存储起来作为一个样本库,
当来了一个新样本时,就在样本库各个类别各抽取n个样本
然后判别新样本中与抽出的样本哪个最相似,就判为哪一个类别
如此重复抽取t次,
最后统计t次中,被判为哪个类别的次数最多,就认为样本属于哪个类别
✍️备注:这里我们使用欧氏距离作为相似度的度量,欧氏距离越小,就认为越相似
蒙特卡罗法用于手写数字识别的算法流程图
算法流程如下:
本节通过代码实现蒙特卡罗法识别手写数字,并展示相关结果
蒙特卡罗法识别手写数字-代码实现
依据上述算法设计,编码蒙特卡罗法用于识别手写数字,并配上GUI界面
共四部分代码
👉1. GUI预测界面主函数
👉2. 数据加载函数
👉3. 数据处理函数
👉4. 蒙特卡罗手写数字预测函数
备注:数据处理函数3、预测函数4与非GUI版本的《蒙特卡罗识别手写数字》是一致的
具体代码如下:
%------代码说明:展示蒙特卡罗法预测手写数字(GUI版本) -----------------
% 来自《老饼讲解神经网络》www.bbbdata.com ,matlab版本:2018a
function initDrawGUI()
clc;clear all ;close all ;
global draw_enable % 绘画状态
global draw_x % 绘画的x坐标
global draw_y % 绘画的y坐标
draw_enable = 0; % 初始化绘画状态
init_img_sample() % 初始化数据
% 界面控件
hMainFig = figure('Tag','mainFig','Name','手写数字识别'); % 新建一个界面
set(hMainFig,'WindowButtonDownFcn',@ButttonDownFcn) % 设置界面鼠标按下的回调函数
set(hMainFig,'WindowButtonUpFcn',@ButttonUpFcn) % 设置界面鼠标按上的回调函数
set(hMainFig,'WindowButtonMotionFcn',@ButttonMotionFcn) % 设置界面鼠标移动的回调函数
% 坐标轴控件
haxes = axes('Parent',hMainFig); % 新建一个坐标轴
set(haxes,'position',[0.1 0.2 0.8 0.7 ]); % 设置坐标轴控件的位置
set(haxes, 'XLim', [-3,3], 'YLim', [-2,2],'Box','on'); % 设置坐标轴的范围
haxes.XAxis.Visible = 'off'; % 隐藏坐标轴x轴
haxes.YAxis.Visible = 'off'; % 隐藏坐标轴y轴
% 建一个按钮-用于清空图像
hbuttonClear = uicontrol(...
'Parent',hMainFig,...
'String','清空',...
'position',[150 30 100 40],...
'Callback',@buttonClearCallBack,...
'Style','pushbutton');
% 建一个按钮-用于识别图像
hbutton = uicontrol(...
'Parent',hMainFig,...
'String','识别',...
'position',[350 30 100 40],...
'Callback',@buttonRecCallBack,...
'Style','pushbutton');
% 清除按钮的回调函数(用于清除画面)
function buttonClearCallBack(hObject,eventdata)
cla; % 清除当前图像
end
% 识别按钮的回调函数(用于数字识别)
function buttonRecCallBack(hObject,eventdata)
% 保存图片
tmp_f_handle = figure('visible','off'); % 新建一个figure
tmp_axes = copyobj(haxes,tmp_f_handle); % 将坐标轴内容复制一份
tmp_axes.Title.Visible='off'; % 隐藏标题
set(tmp_axes,'units','default','position','default'); % 新坐标轴的设置
print(tmp_f_handle, '-djpeg', 'tmp_img_for_recognize.jpg'); % 保存图片
delete(tmp_f_handle); % 删除临时figure
img_rgb = imread('tmp_img_for_recognize.jpg'); % 读取图片
img2 = rgb2gray(img_rgb); % 将图片转为灰度图片
img = zeros(size(img2)); % 初始化图片
img(img2==255) = 0; % 将灰度图片中为白色的地方转为0
img(img2~=255) = 255; % 将灰度图片中不是白色的地方转为255
if(all(img(:)==0)||all(img(:)==255)) % 检测是否没有绘画
title(['请先绘画']); % 提示先绘画
drawnow % 显示标题
return % 直接返回
end
img = process_img(img); % 处理图片
title(['识别中....']); % 标记正在识别
drawnow % 显示标题
predict_y = mc_predict_number(img(:)); % 将图片使用蒙特卡罗法进行预测
number = find(predict_y)-1; % 将识别的one-hot转回数字
title(['识别结果',num2str(number)]); % 显示识别结果
end
% 鼠标按下的回调函数
function ButttonDownFcn(src,event)
draw_enable = 1; % 标记当前为绘画状态
p = get(haxes,'CurrentPoint'); % 获取当前的鼠标坐标
draw_x(1) = p(1,1); % 将当前的鼠标x坐标更新为绘画起点的x
draw_y(1) = p(1,2); % 将当前的鼠标y坐标更新为绘画起点的y
end
% 鼠标弹起的回调函数
function ButttonUpFcn(src,event)
draw_enable = 0; % 标记当前为非绘画状态
end
% 鼠标移动回调函数(用于画图)
function ButttonMotionFcn(src,event)
if (draw_enable==1) % 如果处于画画状态
p= get(haxes,'CurrentPoint'); % 获取鼠标位置点
draw_x(2) = p(1,1); % 将当前鼠标点的x作为画图结束点的x
draw_y(2) = p(1,2); % 将当前鼠标点的y作为画图结束点的y
hold on % 保留之前的画图
line(haxes,draw_x,draw_y,'LineWidth',12) % 画图
draw_x(1) = draw_x(2); % 将绘画终点的x更新为下次绘画起点的x
draw_y(1) = draw_y(2); % 将绘画终点的y更新为下次绘画起点的y
end
end
end
图片数据的加载函数init_img_sample.m代码如下:
function init_img_sample()
% 本部分加载手写数字的样本数据
global sample_x % 样本库的x数据
global sample_y % 样本库的y数据
load digitimages.mat % 加载手写数字数据
[h,w,pic_num] = size(images); % 获取手写数字图片的大小与样本数量
sample_x = zeros(20*20,pic_num); % 初始化图片样本数据
for i =1:pic_num % 逐张图片处理
cur_x = process_img(images(:,:,i)); % 处理当前图片
sample_x(:,i)=cur_x(:); % 存储当前图片
end
sample_y = full(ind2vec(Y'+1)); % 将图片对应的数字转为one-hot矩阵
end
图片的处理函数process_img.m代码如下:
% 预处理图片的函数
function deal_img = process_img(img)
deal_img = img>50; % 将值>50的作为1,<50的作为0
deal_img = truncImgsPadding(deal_img); % 对图片上下左右空白处进行裁剪
deal_img = imresize(deal_img,[20,20]); % 将图片转换为20*20的Size
end
% 裁剪图片空白边缘部分
function trunc_img = truncImgsPadding(imgs)
% 裁剪左右两边的空白处
sum_imgs = sum(imgs); % 按列求和
csum_imgs = cumsum(sum_imgs); % 计算累计值
[~,right_idx] = max(csum_imgs); % 根据累计值找出右边第一个非0列
left_idx = find(csum_imgs>0); % 根据累计值找出非0列
left_idx = left_idx(1); % 第一个非0列就是左边第一个非0列
trunc_img = imgs(:,left_idx:right_idx); % 进行左右裁剪
% 裁剪上下的空白处
sum_imgs = sum(trunc_img,2); % 按行求行
csum_imgs = cumsum(sum_imgs); % 计行累计值
[~,bot_idx] = max(csum_imgs); % 根据累计值找出底部第一个非0行
top_idx = find(csum_imgs>0); % 根据累计值找出非0行
top_idx = top_idx(1); % 第一个非0行就是顶部第一个非0行
trunc_img = trunc_img(top_idx:bot_idx,:); % 对上下进行裁剪
end
蒙特卡罗法的预测手写数字的函数mc_predict_number代码如下:
function y = mc_predict_number(x)
% 用蒙特卡罗法判断样本的类别
global sample_x % 样本库的x数据
global sample_y % 样本库的y数据
setdemorandstream(88888);
t = 100; % 裁决次数
n = 200; % 每个类别抽样数量
[class_num,sample_num] = size(sample_y); % 类别个数与样本个数
class_idx = cell(class_num,1); % 初始化各个类别的样本索引
for i = 1:class_num
class_idx{i} = find(sample_y(i,:)); % 找出属于第i个类别的样本索引
end
% 进行抽样裁决x的类别
rs = zeros(class_num,t); % 初始化裁决结果表
for i = 1:t
select_idx = zeros(n*class_num,1); % 本次抽样的样本索引
for j = 1:class_num % 逐类别抽样
cur_class_idx = class_idx{j}; % 属于第i个类别的样本索引
cur_select_idx = randperm(length(cur_class_idx),n); % 随机抽出n个样本
select_idx((j-1)*n+1:j*n) = cur_class_idx(cur_select_idx); % 记录本次抽出的样本索引
end
select_sample = sample_x(:,select_idx); % 抽取出本次抽样的样本
d = sum((select_sample-x).^2); % 计算各个样本与x的距离
[~,win_idx] = min(d); % 找出最小距离的样本作为本次胜出的样本
win_y = sample_y(:,select_idx(win_idx)); % 根据样本的索引找出y
rs(:,i) = win_y; % 记录本次的获胜的y
end
% 统计多次抽样裁决的结果,用于决定最终x的所属类别
win_stat = sum(rs,2); % 统计各个类别胜出的次数
[~,win_idx] = max(win_stat); % 找出哪个类别胜出次数最多
y = zeros(class_num,1); % 初始化x的类别y
y(win_idx) = 1; % 将胜出次数最多的类别,作为x的类别
end
运行结果
将上述四个函数保存后,运行initDrawGUI.m,显示如下界面:
在上面进行书写后,点击识别,结果如下:
可以看到,已经可以正确地预测出所写的数字
效果分析与补充
经笔者测试,有些数字并不是那么的准确,
在《蒙特卡罗手写数字识别》一文中,使用样本库的样本进行测试,准确率达到99.66%,
但在本文的手写板中,笔者发现,准确率并不是那么的高,时不时就会发生预测不准的情况
粗略分析,主要来源于两方面,
1.样本库字体与实际手写不一致
样本库中的样本来源于国外,与我们的手写字体并不是那么的一致
这应该是引起预测错误的最主要原因
2.算法设计较为粗糙
在算法设计中,为了学习的简便性,
只是简单的使用mse函数来评估图片的相似度,
这对于实际应用中的复杂场景来说,过于粗糙,
精细化图片相似度的评估函数后应该能大大提高准确率
End