博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
手写算式的识别与运算
阅读量:4613 次
发布时间:2019-06-09

本文共 34520 字,大约阅读时间需要 115 分钟。

系统:windows

环境:visual studio 2013

语言:c++

相关:OpenCV 2

内容:给一张简单等式的手写图片,识别这个等式具体内容,并且判断是否成立。

声明:

  有挺多借鉴博客的,过了蛮久了,大多忘了,放上一个还有记录的博客:

  

原理:

1)对图片进行切分

  因为这是简单的等式,所以我所用的切分方法比较朴素。

  1.首先判断图片是否为白底(以左上角像素的亮度判断),若是则反色,使底色为黑色。

  2.然后对图片创建一个副本,对副本进行平滑化处理以及二值化

  3.由于算式是横向书写的,所以之后对副本的每一列进行扫描,分割出许多互相分离的段(得到左右范围)。

  4.对于副本中列的每一段进行扫描,找出这一段对应的行中有图像的范围(得到上下范围)。

  5.对于得到上下左右边界的一个元素,映射到原图像中,并且复制出原图像中的内容。

  6.对复制出的图像进行填充,使其长宽相等。

  7.然后重新设定大小,加上一定厚度的边界,得到的便是分割出来的一个数字或符号的块。

2)对数字的识别。

  我用的MNIST数字集,由网站 http://yann.lecun.com/exdb/mnist/ 提供。具体格式网站内有。

  具体做法其实就是读取MNIST数字集合,然后用BP神经网络进行训练。

3)符号识别

  如果有数据集那就好了……因为我找不到,但是大实验占20分,15周就得交,所以我就很蠢地瞎干了……

  具体做法如下(以下说法中图片均为黑底白字):

  1.由上述图片分割方法可知,所得到的图片中有内容的部分必然位于中心,而且图片大小为28*28(为了配合MNIST集)。

  2.所以对于减号,其白色像素必然集中在中间第11~16行。统计第11~16行中的白色像素,若其占全部白色像素的90%以上,则认定为减号。

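  3.对于加号,其像素必然集中在第11~16行与第11~16列组成的十字形区域中,所以可以类似于减号得到结果:若该区域中的白色像素占全部白色像素的90%以上,则认定为加号。不过要去掉数字“1”与符号“-”的情况:“-”在此之前已由减号规则判出;“1”的像素几乎全部落在第11~16列中,所以若仅第11~16列就含有90%以上的白色像素,则不判为加号。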

  4.对于除号,由于只有除号是横向分为3个部分的(上点、横线、下点),所以比较好分辨:如果逐行扫描能把图片截成三个部分,则标记为除号。同理,能截成两个部分的标记为等号“=”。

  5.对于乘号,白色像素集中在两条对角线附近:以x为列坐标、y为行坐标,即x-y+28或x+y的值在区间[25,30]之内的像素区域。统计一下,如果这块区域中的白色像素占全部白色像素的90%以上,则认定为乘号。

  抱歉,符号识别中,以上做法很蠢,准确率也不高,建议使用机器学习。

4)计算

  由于是数字图像的大实验,用的都是很简单的等式,计算我写得很简单,没有什么可以说的。

代码:

Calcu.h是运算部分的头文件

BP.h是BP神经网络头文件

MNIST.h是对MNIST集读取的头文件

Test.h是为了我自己方便写的头文件……

ImgCut.h是预处理图片(切分等)的头文件

main.cpp不解释了……

 

1 //main.cpp  2   3 #define _CRT_SECURE_NO_WARNINGS  4   5 #include 
6 #include
7 #include
8 #include
9 #include
10 #include
11 #include
12 #include "BP.h" 13 #include "MNIST.h" 14 #include "Calcu.h" 15 #include "ImgCut.h" 16 #include "Test.h" 17 18 #define SYMBOL_IMAGE_NUMBER 6000 19 #define NUMBER_IMAGE_NUMBER 60000 20 #define IMAGE_ROWS 28 21 #define IMAGE_COLS 28 22 #define TEST_PREDICT_NUM 15 23 #define TEST_PREDICT_PREFIX "test" 24 #define CUT_NUM 5 25 #define CUT_PREFIX "" 26 #define TYPE_NUM 10 27 #define DIV 256.0 28 #define EXTRA_SPACE 4 //留出的多余空间 29 30 using namespace std; 31 using namespace cv; 32 33 Mat number_images,number_labels; 34 Mat symbol_images, symbol_labels; 35 vector
sample_group; 36 map
label_to_symbol; 37 map
symbol_to_label; 38 BP bp; 39 ImgCut img_cut[CUT_NUM]; 40 string equation_str[CUT_NUM]; 41 Equation equation[CUT_NUM]; 42 43 string makeStr(int num) 44 { 45 if (num == 0) return "0"; 46 string ret = ""; 47 vector
vec_num; 48 vec_num.clear(); 49 int tmp = num; 50 while (tmp) 51 { 52 vec_num.push_back(tmp % 10); 53 tmp /= 10; 54 } 55 for (int i = vec_num.size() - 1; i >= 0; i--) 56 ret += (char)vec_num[i] + '0'; 57 return ret; 58 } 59 60 void getNumberImages() 61 { 62 Mat mat_tmp; 63 string number_image_name = "resource\\train-images.idx3-ubyte"; 64 string number_label_name = "resource\\train-labels.idx1-ubyte"; 65 int number_of_images, number_of_labels, number_of_rows, number_of_cols; 66 number_images = readImages(number_image_name, number_of_images, number_of_rows, number_of_cols); 67 number_labels = readLabels(number_label_name, number_of_labels); 68 Sample now; 69 for (int i = 0; i < NUMBER_IMAGE_NUMBER; i++) 70 { 71 now.in.clear(); now.out.clear(); 72 for (int j = 0; j < number_images.cols; j++) 73 now.in.push_back((double)((int)number_images.at
(i, j)) / DIV); 74 for (int j = 0; j < TYPE_NUM; j++) 75 if (j == (int)number_labels.at
(i, 0)) 76 now.out.push_back(1); 77 else now.out.push_back(0); 78 sample_group.push_back(now); 79 if ((i + 1) % 1000 == 0) 80 printf("number image: %d/60000 completed\n", i + 1); 81 } 82 } 83 84 void test() 85 { 86 /* 87 //测试 Calcu.h 88 char s_a[44] = "132*3=396", s_b[44] = "132*3=196"; 89 Equation equ_a(s_a, strlen(s_a)); 90 Equation equ_b(s_b, strlen(s_b)); 91 */ 92 /* 93 //bp.h 异或测试 94 BP bp; 95 //学习样本 96 vector
sample_in[4], sample_out[4]; 97 Sample sample[4]; 98 sample_in[0].push_back(0); sample_in[0].push_back(0); sample_out[0].push_back(0); 99 sample_in[1].push_back(0); sample_in[1].push_back(1); sample_out[1].push_back(1);100 sample_in[2].push_back(1); sample_in[2].push_back(0); sample_out[2].push_back(1);101 sample_in[3].push_back(1); sample_in[3].push_back(1); sample_out[3].push_back(0);102 for (int i = 0; i < 4; i++)103 sample[i].in = sample_in[i], sample[i].out = sample_out[i];104 vector
sample_group(sample, sample + 4);105 bp.training(sample_group, 1000);106 //测试数据107 vector
test_in[4], test_out[4];108 Sample test[4];109 test_in[0].push_back(-0.1); test_in[0].push_back(0.2);110 test_in[1].push_back(0.15); test_in[1].push_back(0.94);111 test_in[2].push_back(1.5); test_in[2].push_back(-0.04);112 test_in[3].push_back(0.90); test_in[3].push_back(1.23);113 for (int i = 0; i < 4; i++)114 test[i].in = test_in[i];115 vector
test_group(test, test + 4);116 bp.predict(test_group);117 for (int i = 0; i < test_group.size(); i++)118 {119 for (int j = 0; j < test_group[i].in.size(); j++)120 printf("%lf\t", test_group[i].in[j]);121 printf(" : ");122 for (int j = 0; j < test_group[i].out.size(); j++)123 printf("%lf\t", test_group[i].out[j]);124 puts("");125 }126 pause;127 */128 }129 130 void train()131 {132 sample_group.clear();133 getNumberImages();134 puts("tranning start");135 bp.training(sample_group, 300000);136 bp.write();137 }138 139 char getResult(double xret)140 {141 int num = (int)(xret + 0.5);142 if (num >= 0 && num <= 9)143 return num + '0';144 char sym_lst[5] = { '+', '-', '*', '/', '=' };145 return sym_lst[14 - num];146 }147 148 // tag: '=' : 10 , '/' : 11 , '*' : 12 , '-' : 13 , '+' : 14149 int getTag(Mat tmp_mat)150 {151 int ret = -1, tmp;152 int i, j, row_id, col_id;153 int bas = 28;154 int bound1[2] = { 11, 16 }, bound2[2] = { 25, 30 }, cnt14, cnt12, cnt0, cnt13, xcnt14;155 int cut_num = 0, cnt_cols[IMAGE_COLS + EXTRA_SPACE], cnt_rows[IMAGE_ROWS + EXTRA_SPACE];156 bool ctn_flag = false;157 threshold(tmp_mat, tmp_mat, 255 / 2, 255, THRESH_BINARY);158 // showPic(tmp_mat);159 memset(cnt_rows, 0, sizeof(cnt_rows));160 memset(cnt_cols, 0, sizeof(cnt_cols));161 for (int i = 0; i < tmp_mat.rows; i++) //判断能被截成几段162 {163 for (int j = 0; j < tmp_mat.cols; j++)164 if (tmp_mat.at
(i, j) == 255) cnt_cols[j]++, cnt_rows[i]++;165 if (cnt_rows[i])166 {167 if (ctn_flag == false)168 ctn_flag = true, cut_num++;169 }170 else ctn_flag = false;171 }172 if (cut_num == 3) return 11; //结果为"÷" 除号173 if (cut_num == 2) return 10; //结果为"=" 等于号174 cnt0 = cnt12 = cnt13 = cnt14 = xcnt14 = 0;175 for (int i = 0; i < tmp_mat.rows; i++) //统计各个区域的白色像素的量176 for (int j = 0; j < tmp_mat.cols; j++)177 if (tmp_mat.at
(i, j) == 255)178 {179 cnt0++;180 if ((i >= bound1[0] && i <= bound1[1]) || (j >= bound1[0] && j <= bound1[1]))181 cnt14++;182 if ((j - i + bas >= bound2[0] && j - i + bas <= bound2[1]) || (i + j >= bound2[0] && i + j <= bound2[1]))183 cnt12++;184 if (i >= bound1[0] && i <= bound1[1])185 cnt13++;186 if (j >= bound1[0] && j <= bound1[1])187 xcnt14++;188 }189 /*190 for (int i = 0; i < tmp_mat.rows; i++)191 {192 for (int j = 0; j < tmp_mat.cols; j++)193 printf("%03d ", tmp_mat.at
(i, j));194 cout << endl;195 }196 */197 printf("cnt12 = %d cnt13 = %d cnt14 = %d cnt0 = %d xcnt14 = %d\n", cnt12, cnt13, cnt14, cnt0, xcnt14);198 if (cnt13 > cnt0*0.9) //结果为'-' 减号199 return 13;200 if (cnt14 > cnt0*0.9 && xcnt14 < cnt0*0.9) //结果为'+' 加号201 return 14;202 if (cnt12 > cnt0*0.9) //结果为'x' 乘号203 return 12;204 return ret = -1;205 }206 207 void predict()208 {209 double tmp_mx;210 int id_mx;211 Mat now_img, test_images, test_labels, mat_tmp;212 Sample now;213 int tag;214 //测试1 MNIST测试集测试215 /*216 sample_group.clear();217 string test_image_name = "resource\\t10k-images.idx3-ubyte";218 string test_label_name = "resource\\t10k-labels.idx1-ubyte";219 int number_of_images, number_of_labels, number_of_rows, number_of_cols, corret_number = 0;220 test_images = readImages(test_image_name, number_of_images, number_of_rows, number_of_cols);221 test_labels = readLabels(test_label_name, number_of_labels);222 mat_tmp.create(number_of_rows, number_of_cols, CV_8UC1);223 for (int i = 0; i < number_of_images; i++)224 {225 now.in.clear(); now.out.clear();226 // for (int p = 0; p < number_of_rows; p++)227 // for (int q = 0; q < number_of_cols; q++)228 // mat_tmp.at
(p, q) = (int)test_images.at
(i, q + p*(number_of_cols));229 // showPic(mat_tmp);230 for (int j = 0; j < number_of_rows*number_of_cols; j++)231 now.in.push_back((int)test_images.at
(i, j)/DIV);232 sample_group.push_back(now);233 }234 bp.predict(sample_group);235 for (int i = 0; i < number_of_images; i++)236 {237 printf("id=%d:\n", i);238 tmp_mx = -1;239 for (int j = 0; j < sample_group[i].out.size(); j++)240 {241 printf("%lf ", sample_group[i].out[j]);242 if (tmp_mx < sample_group[i].out[j]) tmp_mx = sample_group[i].out[j], id_mx = j;243 }244 puts("");245 printf("answer is %d\n", id_mx);246 printf("true answer is %d\n", (int)test_labels.at
(i, 0));247 if (id_mx == (int)test_labels.at
(i, 0))248 corret_number++;249 }250 printf("corret rate is %.4lf\n", 1.0*corret_number / number_of_images);251 pause;252 //*/253 //测试2254 /*255 //cout << "!" << endl;256 sample_group.clear();257 string dir = "predict\\sav\\";258 for (int i = 0; i < TEST_PREDICT_NUM; i++)259 {260 now_img = imread(dir + TEST_PREDICT_PREFIX + makeStr(i) + ".jpg");261 cvtColor(now_img, now_img, CV_BGR2GRAY);262 // threshold(now_img, now_img, 255 / 3 * 2, 255, THRESH_BINARY);263 now.in.clear(); now.out.clear();264 for (int p = 0; p < now_img.rows; p++)265 for (int q = 0; q < now_img.cols; q++)266 now.in.push_back(((int)(now_img.at
(p, q))) / DIV);267 sample_group.push_back(now);268 }269 bp.predict(sample_group);270 for (int i = 0; i < TEST_PREDICT_NUM; i++)271 {272 printf("id=%d:\n", i);273 now_img = imread(dir + TEST_PREDICT_PREFIX + makeStr(i) + ".jpg");274 cvtColor(now_img, now_img, CV_BGR2GRAY);275 tag = getTag(now_img);276 tmp_mx = -1;277 for (int j = 0; j < sample_group[i].out.size(); j++)278 {279 printf("%lf ", sample_group[i].out[j]);280 if (tmp_mx < sample_group[i].out[j]) tmp_mx = sample_group[i].out[j], id_mx = j;281 }282 if (tag != -1)283 id_mx = tag;284 puts("");285 printf("answer is %d\n", id_mx);286 }287 pause;288 //*/289 //predict部分290 string cut_part_name = "predict\\cut_part\\";291 for (int i = 0; i < CUT_NUM; i++)292 {293 sample_group.clear();294 equation_str[i] = "";295 for (int j = 0; j < img_cut[i].pec_num; j++)296 {297 now.in.clear(); now.out.clear();298 now_img = img_cut[i].pec[j].clone();299 // showPic(now_img);300 for (int p = 0; p < img_cut[i].pec[j].rows; p++)301 for (int q = 0; q < img_cut[i].pec[j].cols; q++)302 now.in.push_back(((int)(img_cut[i].pec[j].at
(p, q))) / DIV);303 sample_group.push_back(now);304 }305 bp.predict(sample_group);306 for (int j = 0; j < img_cut[i].pec_num; j++)307 {308 // showPic(img_cut[i].pec[j]);309 imwrite(cut_part_name + makeStr(i) + "-" + makeStr(j) + ".jpg", img_cut[i].pec[j]);310 tag = getTag(img_cut[i].pec[j]);311 tmp_mx = -1;312 for (int p = 0; p < sample_group[j].out.size(); p++)313 {314 printf("%lf ", sample_group[j].out[p]);315 if (tmp_mx < sample_group[j].out[p]) tmp_mx = sample_group[j].out[p], id_mx = p;316 }317 if (tag != -1)318 id_mx = tag;319 puts("");320 equation_str[i] += getResult(id_mx);321 }322 printf("equation id=%d : %s\n", i, equation_str[i].c_str());323 }324 }325 326 void init()327 {328 label_to_symbol.clear(); symbol_to_label.clear();329 label_to_symbol['+'] = -5; label_to_symbol['-'] = -4; label_to_symbol['*'] = -3; label_to_symbol['/'] = -2; label_to_symbol['='] = -1;330 symbol_to_label[-5] = '+'; symbol_to_label[-4] = '-'; symbol_to_label[-3] = '*'; symbol_to_label[-2] = '/'; symbol_to_label[-1] = '=';331 }332 333 void cut()334 {335 string dir = "predict\\";336 Mat now_img;337 for (int i = 0; i < CUT_NUM; i++)338 {339 now_img = imread(dir + CUT_PREFIX + makeStr(i) + ".jpg");340 img_cut[i] = ImgCut(now_img);341 for (int j = 0; j < img_cut[i].pec_num; j++)342 {343 printf("dealling with id = %d\n", j);344 // showPic(img_cut[i].pec[j]);345 }346 }347 }348 349 void calcu()350 {351 for (int i = 0; i < CUT_NUM; i++)352 equation[i] = Equation(equation_str[i].c_str(), equation_str[i].length());353 }354 355 int main()356 {357 int wh;358 init();359 // test();360 cut();361 bp = BP();362 puts("\ninput\n 0: load last trainning data\n 1: restart tranning\n 2: load last tranning data and continue trainning\n");363 scanf("%d", &wh);364 if (wh == 1) train();365 else if (wh == 0) bp.load();366 else if (wh == 2) bp.load(), train();367 else return 0 & puts("error");368 predict();369 calcu();370 pause;371 return 0;372 }
main.cpp
//Calcu.h
// Equation: parses a recognised equation string such as "132*3=396" (exactly one
// binary operator among + - * /, non-negative integer operands), evaluates the
// left-hand side and checks it against the right-hand side. The verdict is
// printed and stored in ok / ans_flag / result.

#ifndef CALCU_H
#define CALCU_H

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>

#define MAX_LEN 44

using namespace std;

class Equation
{
public:
    char pre[MAX_LEN];  // left-hand side (text before '='), NUL-terminated
    bool ok;            // left-hand side parsed and evaluated successfully
    bool ans_flag;      // right-hand side is a non-empty run of digits
    bool result;        // ok && ans_flag && ans == true_ans
    int len_pre;        // number of characters stored in pre
    int ans;            // value written after '='
    int true_ans;       // value computed from pre

    // Decimal value of str[0..len); the caller guarantees digits only.
    int getNum(const char *str, int len)
    {
        int ret = 0;
        for (int i = 0; i < len; i++)
            ret = ret * 10 + (str[i] - '0');
        return ret;
    }

    // True when str[0..len) is non-empty and consists only of '0'..'9'.
    bool isDigits(const char *str, int len)
    {
        if (len <= 0)
            return false;
        for (int i = 0; i < len; i++)
            if (str[i] > '9' || str[i] < '0')
                return false;
        return true;
    }

    // Parse the right-hand side into ans; false if it is empty or not all digits.
    bool getAns(const char *str, int len)
    {
        if (!isDigits(str, len))
            return false;
        ans = getNum(str, len);
        return true;
    }

    bool isSym(char x)
    {
        return x == '-' || x == '+' || x == '*' || x == '/';
    }

    // Evaluate pre as "<digits><op><digits>" into true_ans. Returns false for any
    // other shape and for division by zero.
    bool calcu()
    {
        if (len_pre < 3 || isSym(pre[0]) || isSym(pre[len_pre - 1])) return false;
        int sym_cnt = 0, plc = -1;
        for (int i = 0; i < len_pre; i++)
            if (isSym(pre[i])) sym_cnt++, plc = i;
        if (sym_cnt != 1) return false;
        const char *rhs = pre + plc + 1;
        int rhs_len = len_pre - plc - 1;
        if (!isDigits(pre, plc) || !isDigits(rhs, rhs_len)) return false;
        int a = getNum(pre, plc), b = getNum(rhs, rhs_len);
        switch (pre[plc])
        {
        case '+': true_ans = a + b; break;
        case '-': true_ans = a - b; break;
        case '*': true_ans = a * b; break;
        default: // '/'
            if (b == 0) return false;
            true_ans = a / b;
            break;
        }
        return true;
    }

    Equation() : ok(false), ans_flag(false), result(false), len_pre(0), ans(0), true_ans(0)
    {
        pre[0] = '\0';
    }

    // Parse str[0..len) and print the verdict.
    Equation(const char *str, int len) : ok(false), ans_flag(false), result(false), len_pre(0), ans(0), true_ans(0)
    {
        int eq = 0;
        while (eq < len && str[eq] != '=')
            eq++;
        // A left-hand side that does not fit in pre is treated as unrecognisable.
        bool lhs_fits = eq < MAX_LEN;
        len_pre = std::min(eq, MAX_LEN - 1);
        memcpy(pre, str, len_pre);
        pre[len_pre] = '\0';
        // Without '=' there is no answer to read.
        ans_flag = eq < len && getAns(str + eq + 1, len - eq - 1);
        if (ans_flag == false) puts("答案识别发生错误");
        ok = lhs_fits && calcu();
        if (ok == false) puts("算式识别错误");
        else
        {
            result = ans_flag && ans == true_ans;
            puts(result ? "答案正确" : "答案错误");
        }
    }
};

#endif
Calcu.h
//BP.h
// Fully-connected back-propagation network with sigmoid activations:
// IN_NODE_NUM inputs -> HIDDEN_LAYER_NUM hidden layers of HIDDEN_NODE_NUM nodes
// -> OUT_NODE_NUM outputs. training() updates the weights after every single
// sample (online gradient descent on squared error).

#ifndef BP_H
#define BP_H

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <string>
#include <vector>

#define IN_NODE_NUM (28*28)     // number of input nodes
#define HIDDEN_NODE_NUM 16      // nodes per hidden layer
#define HIDDEN_LAYER_NUM 1      // number of hidden layers
#define OUT_NODE_NUM 10         // number of output nodes
#define LEARNING_RATE 0.3       // learning rate
#define RD rand()%1000
#define POSI 1000               // sample-selection probability in per mille (unused)
#define MAX_RAND_SEG (int)144e4 // capacity of the shuffle buffer rand_seg

using namespace std;

// Uniform random initial weight in [-0.1, 0.1].
inline double xrand()
{
    return ((2.0*(double)rand() / RAND_MAX) - 1) / 10.0;
}

inline double sigmoid(double x)
{
    return 1 / (1 + exp(-x));
}

struct InputNode
{
    double value;               // current input value
    vector<double> weight;      // weights to each node of the first hidden layer
    vector<double> wdelta_sum;  // accumulated weight updates of the current step

    InputNode() : value(0) {}
};

struct OutputNode
{
    double value;
    double delta;       // (target - value) * sigmoid'(.)
    double rightout;    // target value
    double bias;
    double bdelta_sum;  // accumulated bias update

    OutputNode() : value(0), delta(0), rightout(0), bias(0), bdelta_sum(0) {}
};

struct HiddenNode
{
    double value;
    double delta;               // back-propagated error signal
    double bias;
    double bdelta_sum;          // accumulated bias update
    vector<double> weight;      // weights to each node of the next layer
    vector<double> wdelta_sum;  // accumulated weight updates

    HiddenNode() : value(0), delta(0), bias(0), bdelta_sum(0) {}
};

// Shuffle buffer used by training(). NOTE(review): a global defined in a header,
// so this header must be included by only one translation unit.
struct RandSegNode
{
    int id, val;
} rand_seg[MAX_RAND_SEG];

struct Sample
{
    vector<double> in, out;
};

inline bool cmpRandSeg(RandSegNode a, RandSegNode b)
{
    return a.val < b.val;
}

class BP
{
public:
    double error;  // squared error accumulated during the latest training step
    InputNode* input_layer[IN_NODE_NUM];
    OutputNode* output_layer[OUT_NODE_NUM];
    HiddenNode* hidden_layer[HIDDEN_LAYER_NUM][HIDDEN_NODE_NUM];
    // NOTE(review): nodes are heap-allocated and never freed, and copying a BP
    // copies only the pointers. Existing callers copy-assign BP objects, so the
    // ownership model is left unchanged.

    // Read weights and biases from data\data.txt, in the order write() uses.
    void load()
    {
        string file_name = "data\\data.txt";
        ifstream infile(file_name.c_str(), ios::in);
        if (!infile.is_open())
        {
            printf("cannot open %s, weights left unchanged\n", file_name.c_str());
            return;
        }
        for (int i = 0; i < IN_NODE_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                infile >> input_layer[i]->weight[j];
        for (int k = 0; k < HIDDEN_LAYER_NUM - 1; k++)
            for (int i = 0; i < HIDDEN_NODE_NUM; i++)
                for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                    infile >> hidden_layer[k][i]->weight[j];
        for (int i = 0; i < HIDDEN_NODE_NUM; i++)
            for (int j = 0; j < OUT_NODE_NUM; j++)
                infile >> hidden_layer[HIDDEN_LAYER_NUM - 1][i]->weight[j];
        for (int k = 0; k < HIDDEN_LAYER_NUM; k++)
            for (int i = 0; i < HIDDEN_NODE_NUM; i++)
                infile >> hidden_layer[k][i]->bias;
        for (int i = 0; i < OUT_NODE_NUM; i++)
            infile >> output_layer[i]->bias;
        if (infile.fail())
            printf("%s is truncated or malformed\n", file_name.c_str());
    }

    // Write weights and biases to data\data.txt (space separated).
    void write()
    {
        string file_name = "data\\data.txt";
        ofstream outfile(file_name.c_str(), ios::out);
        if (!outfile.is_open())
        {
            printf("cannot open %s for writing\n", file_name.c_str());
            return;
        }
        for (int i = 0; i < IN_NODE_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                outfile << input_layer[i]->weight[j] << ' ';
        for (int k = 0; k < HIDDEN_LAYER_NUM - 1; k++)
            for (int i = 0; i < HIDDEN_NODE_NUM; i++)
                for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                    outfile << hidden_layer[k][i]->weight[j] << ' ';
        for (int i = 0; i < HIDDEN_NODE_NUM; i++)
            for (int j = 0; j < OUT_NODE_NUM; j++)
                outfile << hidden_layer[HIDDEN_LAYER_NUM - 1][i]->weight[j] << ' ';
        for (int k = 0; k < HIDDEN_LAYER_NUM; k++)
            for (int i = 0; i < HIDDEN_NODE_NUM; i++)
                outfile << hidden_layer[k][i]->bias << ' ';
        for (int i = 0; i < OUT_NODE_NUM; i++)
            outfile << output_layer[i]->bias << ' ';
    }

    // Seed rand() from the clock and build the network with random weights in
    // [-0.1, 0.1] and zero biases.
    BP()
    {
        srand((unsigned)time(NULL));
        error = 100;
        for (int i = 0; i < IN_NODE_NUM; i++)
        {
            input_layer[i] = new InputNode();
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                input_layer[i]->weight.push_back(xrand());
                input_layer[i]->wdelta_sum.push_back(0);
            }
        }
        for (int i = 0; i < HIDDEN_LAYER_NUM; i++)
        {
            int next_num = nextLayerSize(i);
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                hidden_layer[i][j] = new HiddenNode();
                hidden_layer[i][j]->bias = 0;
                for (int k = 0; k < next_num; k++)
                {
                    hidden_layer[i][j]->weight.push_back(xrand());
                    // Every hidden layer needs an accumulator per outgoing weight.
                    hidden_layer[i][j]->wdelta_sum.push_back(0);
                }
            }
        }
        for (int i = 0; i < OUT_NODE_NUM; i++)
        {
            output_layer[i] = new OutputNode();
            output_layer[i]->bias = 0;
        }
    }

    // Number of nodes fed by hidden layer i (the output layer for the last one).
    static int nextLayerSize(int i)
    {
        return i == HIDDEN_LAYER_NUM - 1 ? OUT_NODE_NUM : HIDDEN_NODE_NUM;
    }

    // delta of node k in the layer that follows hidden layer i.
    double nextDelta(int i, int k) const
    {
        return i == HIDDEN_LAYER_NUM - 1 ? output_layer[k]->delta : hidden_layer[i + 1][k]->delta;
    }

    // Forward pass for the input currently set by setInput().
    void forwardPropagationEpoc()
    {
        for (int i = 0; i < HIDDEN_LAYER_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                double sum = 0;
                if (i == 0)
                    for (int k = 0; k < IN_NODE_NUM; k++)
                        sum += input_layer[k]->value * input_layer[k]->weight[j];
                else
                    for (int k = 0; k < HIDDEN_NODE_NUM; k++)
                        sum += hidden_layer[i - 1][k]->value * hidden_layer[i - 1][k]->weight[j];
                sum += hidden_layer[i][j]->bias;
                hidden_layer[i][j]->value = sigmoid(sum);
            }
        for (int i = 0; i < OUT_NODE_NUM; i++)
        {
            double sum = 0;
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                sum += hidden_layer[HIDDEN_LAYER_NUM - 1][j]->value * hidden_layer[HIDDEN_LAYER_NUM - 1][j]->weight[i];
            sum += output_layer[i]->bias;
            output_layer[i]->value = sigmoid(sum);
        }
    }

    // Backward pass: compute deltas, add the squared error to `error`, and
    // accumulate the weight/bias updates into the *_sum fields.
    void backPropagationEpoc()
    {
        for (int i = 0; i < OUT_NODE_NUM; i++)
        {
            double tmp = output_layer[i]->rightout - output_layer[i]->value;
            error += tmp*tmp / 2;
            output_layer[i]->delta = tmp*(1 - output_layer[i]->value)*output_layer[i]->value;
        }
        for (int i = HIDDEN_LAYER_NUM - 1; i >= 0; i--)
        {
            int next_num = nextLayerSize(i);
            for (int j = 0; j < HIDDEN_NODE_NUM; j++) // every node of the layer
            {
                double sum = 0;
                for (int k = 0; k < next_num; k++)
                    sum += nextDelta(i, k)*hidden_layer[i][j]->weight[k];
                hidden_layer[i][j]->delta = sum*(1 - hidden_layer[i][j]->value)*hidden_layer[i][j]->value;
            }
        }
        for (int i = 0; i < IN_NODE_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                input_layer[i]->wdelta_sum[j] += input_layer[i]->value*hidden_layer[0][j]->delta;
        for (int i = 0; i < HIDDEN_LAYER_NUM; i++)
        {
            int next_num = nextLayerSize(i);
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                hidden_layer[i][j]->bdelta_sum += hidden_layer[i][j]->delta;
                for (int k = 0; k < next_num; k++)
                    hidden_layer[i][j]->wdelta_sum[k] += hidden_layer[i][j]->value*nextDelta(i, k);
            }
        }
        for (int i = 0; i < OUT_NODE_NUM; i++)
            output_layer[i]->bdelta_sum += output_layer[i]->delta;
    }

    // Zero every accumulated weight/bias update.
    void clearDeltaSums()
    {
        for (int i = 0; i < IN_NODE_NUM; i++)
            input_layer[i]->wdelta_sum.assign(input_layer[i]->wdelta_sum.size(), 0);
        for (int i = 0; i < HIDDEN_LAYER_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                hidden_layer[i][j]->wdelta_sum.assign(hidden_layer[i][j]->wdelta_sum.size(), 0);
                hidden_layer[i][j]->bdelta_sum = 0;
            }
        for (int i = 0; i < OUT_NODE_NUM; i++)
            output_layer[i]->bdelta_sum = 0;
    }

    // Apply the accumulated updates scaled by LEARNING_RATE.
    void applyDeltaSums()
    {
        for (int i = 0; i < IN_NODE_NUM; i++)
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
                input_layer[i]->weight[j] += LEARNING_RATE*input_layer[i]->wdelta_sum[j];
        for (int i = 0; i < HIDDEN_LAYER_NUM; i++)
        {
            int next_num = nextLayerSize(i);
            for (int j = 0; j < HIDDEN_NODE_NUM; j++)
            {
                hidden_layer[i][j]->bias += LEARNING_RATE*hidden_layer[i][j]->bdelta_sum;
                for (int k = 0; k < next_num; k++)
                    hidden_layer[i][j]->weight[k] += LEARNING_RATE*hidden_layer[i][j]->wdelta_sum[k];
            }
        }
        for (int i = 0; i < OUT_NODE_NUM; i++)
            output_layer[i]->bias += LEARNING_RATE*output_layer[i]->bdelta_sum;
    }

    // Run cnt_bound single-sample training steps, cycling through the samples in
    // a random order fixed once per call. Empty input is a no-op; only the first
    // MAX_RAND_SEG samples take part.
    void training(const vector<Sample>& sample_group, int cnt_bound)
    {
        int sample_num = (int)std::min(sample_group.size(), (size_t)MAX_RAND_SEG);
        if (sample_num == 0)
            return;
        for (int i = 0; i < sample_num; i++)
            rand_seg[i].id = i, rand_seg[i].val = rand();
        sort(rand_seg, rand_seg + sample_num, cmpRandSeg);
        int cnt = 0;
        while (cnt < cnt_bound)
        {
            error = 0;
            clearDeltaSums();
            int now_id = rand_seg[cnt % sample_num].id;
            setInput(sample_group[now_id].in);
            setOutput(sample_group[now_id].out);
            forwardPropagationEpoc();
            backPropagationEpoc();
            applyDeltaSums();
            if (++cnt % 10000 == 0)
            {
                printf("turn %d/%d finished \n", cnt, cnt_bound);
                printf("training error: %lf\n", error);
            }
        }
    }

    // Fill test_group[i].out with the OUT_NODE_NUM network outputs for test_group[i].in.
    void predict(vector<Sample>& test_group)
    {
        for (size_t id = 0; id < test_group.size(); id++)
        {
            test_group[id].out.clear();
            setInput(test_group[id].in);
            forwardPropagationEpoc();
            for (int i = 0; i < OUT_NODE_NUM; i++)
                test_group[id].out.push_back(output_layer[i]->value);
        }
    }

    // Load an input vector; missing trailing entries are treated as 0.
    void setInput(const vector<double>& sample_in)
    {
        for (int i = 0; i < IN_NODE_NUM; i++)
            input_layer[i]->value = i < (int)sample_in.size() ? sample_in[i] : 0.0;
    }

    // Load a target vector; missing trailing entries are treated as 0.
    void setOutput(const vector<double>& sample_out)
    {
        for (int i = 0; i < OUT_NODE_NUM; i++)
            output_layer[i]->rightout = i < (int)sample_out.size() ? sample_out[i] : 0.0;
    }
};

#endif
BP.h
1 //MNIST.h  2   3 #ifndef MNIST_H  4 #define MNIST_H  5   6 #include 
7 #include
8 #include
9 #include "BP.h" 10 11 #define MAGIC_NUMBER_OF_IMAGE 2051 12 #define MAGIC_NUMBER_OF_LABEL 2049 13 14 using namespace std; 15 using namespace cv; 16 17 struct MNISTImageFileHeader //MNIST图片结构体 18 { 19 unsigned char magic_number[4]; 20 unsigned char number_of_images[4]; 21 unsigned char number_of_rows[4]; 22 unsigned char number_of_colums[4]; 23 }; 24 25 struct MNISTLabelFileHeader //MNIST标签结构体 26 { 27 unsigned char magic_number[4]; 28 unsigned char number_of_labels[4]; 29 }; 30 31 int converCharArrayToInt(unsigned char* array, int length_of_array) 32 { 33 if (length_of_array < 0) 34 return -1; 35 int result = static_cast
(array[0]); 36 for (int i = 1; i < length_of_array; i++) 37 result = (result << 8) + array[i]; 38 return result; 39 } 40 41 bool isImageDateFile(unsigned char* magic_number, int length_of_array) 42 { 43 int magic_number_of_image = converCharArrayToInt(magic_number, length_of_array); 44 if (magic_number_of_image == MAGIC_NUMBER_OF_IMAGE) 45 return true; 46 return false; 47 } 48 49 bool isLabelDateFile(unsigned char* magic_number, int length_of_array) 50 { 51 int magic_number_of_label = converCharArrayToInt(magic_number, length_of_array); 52 if (magic_number_of_label == MAGIC_NUMBER_OF_LABEL) 53 return true; 54 return false; 55 } 56 57 Mat readData(fstream &data_file, int number_of_datas, int data_size_in_bytes) 58 { 59 Mat data_mat; 60 if (data_file.is_open()) 61 { 62 int all_data_size_in_bytes = data_size_in_bytes*number_of_datas; 63 unsigned char* tmp_data = new unsigned char[all_data_size_in_bytes]; 64 data_file.read((char*)tmp_data, all_data_size_in_bytes); 65 data_mat = Mat(number_of_datas, data_size_in_bytes, CV_8UC1, tmp_data).clone(); 66 delete []tmp_data; 67 data_file.close(); 68 } 69 return data_mat; 70 } 71 72 Mat readImageData(fstream &image_data_file, int number_of_images) 73 { 74 int image_size_in_bytes = 28 * 28; 75 return readData(image_data_file, number_of_images, image_size_in_bytes); 76 } 77 78 Mat readLabelData(fstream &label_data_file, int number_of_labels) 79 { 80 int label_size_in_bytes = 1; 81 return readData(label_data_file, number_of_labels, label_size_in_bytes); 82 } 83 84 Mat readImages(string &file_name, int &number_of_images, int &number_of_rows, int &number_of_cols) 85 { 86 fstream file(file_name.c_str(), std::ios_base::in | std::ios_base::binary); 87 if (!file.is_open()) 88 return Mat(); 89 MNISTImageFileHeader file_header; 90 file.read((char*)(&file_header), sizeof(file_header)); 91 if (!isImageDateFile(file_header.magic_number, 4)) 92 return Mat(); 93 number_of_images = converCharArrayToInt(file_header.number_of_images, 4); 94 
number_of_rows = converCharArrayToInt(file_header.number_of_rows, 4); 95 number_of_cols = converCharArrayToInt(file_header.number_of_colums, 4); 96 return readImageData(file, number_of_images); 97 } 98 99 Mat readLabels(string &file_name, int &number_of_images)100 {101 fstream file(file_name.c_str(), ios_base::in | ios_base::binary);102 if (!file.is_open())103 return Mat();104 MNISTLabelFileHeader file_header;105 file.read((char*)(&file_header), sizeof(file_header));106 if (!isLabelDateFile(file_header.magic_number, 4))107 return Mat();108 number_of_images = converCharArrayToInt(file_header.number_of_labels, 4);109 return readLabelData(file, number_of_images);110 }111 112 #endif
MNIST.h
1 //Test.h 2  3 #ifndef TEST_H 4 #define TEST_H 5  6 #include 
7 #include
8 #include
9 #include
10 #include
11 #include
12 #include
13 #include
14 15 #define pause system("pause")16 17 using namespace std;18 using namespace cv;19 20 void showPic(Mat tmp_mat)21 {22 imshow("x", tmp_mat);23 cvWaitKey(0);24 pause;25 }26 27 #endif
Test.h
1 //ImgCut.h  2   3 #ifndef IMG_CUT_H  4 #define IMG_CUT_H  5   6 #include 
7 #include
8 #include
9 #include
10 #include
11 #include "MNIST.h" 12 #include "Calcu.h" 13 #include "Test.h" 14 15 using namespace std; 16 using namespace cv; 17 18 const int M = 1e4 + 44; 19 const int N = 14; 20 const int INF = 1e9 + 7; 21 22 class ImgCut 23 { 24 public: 25 bool tag[M], xtag[M]; 26 Mat img, pec[M], tmp_mat, img_sav; 27 int pec_num; 28 29 int chg(int val) 30 { 31 int ret; 32 if (val <= 40) return 0; 33 if (val >= 215) return 255; 34 return (int)((51.0 / 35)*val - (408.0) / 7); 35 } 36 37 void deal(int li, int ri) 38 { 39 Mat tmp; 40 int upi = INF, dwi, new_sz, sz_bas; 41 memset(xtag, 0, sizeof(xtag)); 42 for (int i = 0; i < img.rows; i++) 43 for (int j = li; j <= ri; j++) 44 if ((int)(img.at
(i, j)) == 255) 45 xtag[i] = 1; 46 for (int i = 0; i < img.rows; i++) 47 if (xtag[i] == 1) dwi = i, upi = min(upi, i); 48 pec[pec_num].create(dwi - upi + 1, ri - li + 1, CV_8UC1); 49 for (int i = upi; i <= dwi; i++) 50 for (int j = li; j <= ri; j++) 51 pec[pec_num].at
(i - upi, j - li) = img_sav.at
(i, j); //裁出图片 52 new_sz = max(dwi - upi + 1, ri - li + 1); //使长宽相等,多余部分填充黑色 53 tmp_mat.create(new_sz, new_sz, CV_8UC1); 54 for (int i = 0; i < tmp_mat.rows; i++) 55 for (int j = 0; j < tmp_mat.cols; j++) 56 tmp_mat.at
(i, j) = 0; 57 if (pec[pec_num].rows < new_sz) 58 { 59 sz_bas = new_sz - pec[pec_num].rows, sz_bas /= 2; 60 for (int i = 0; i < pec[pec_num].rows; i++) 61 for (int j = 0; j < pec[pec_num].cols; j++) 62 tmp_mat.at
(sz_bas + i, j) = pec[pec_num].at
(i, j); 63 } 64 else 65 { 66 sz_bas = new_sz - pec[pec_num].cols, sz_bas /= 2; 67 for (int i = 0; i < pec[pec_num].rows; i++) 68 for (int j = 0; j < pec[pec_num].cols; j++) 69 tmp_mat.at
(i, sz_bas + j) = pec[pec_num].at
(i, j); 70 } 71 resize(tmp_mat, tmp_mat , Size(22, 22), 0, 0, INTER_LINEAR); //重新设定大小 72 // showPic(tmp_mat); 73 pec[pec_num].create(28, 28, CV_8UC1); 74 for (int i = 0; i < 28; i++) 75 for (int j = 0; j < 28; j++) 76 pec[pec_num].at
(i, j) = 0; //填充 77 for (int i = 3; i < 25; i++) 78 for (int j = 3; j < 25; j++) 79 pec[pec_num].at
(i, j) = ((int)tmp_mat.at
(i - 3, j - 3)); //留出边框 80 for (int i = 0; i < 28; i++) //加强图片 81 for (int j = 0; j < 28; j++) 82 pec[pec_num].at
(i, j) = chg((int)(pec[pec_num].at
(i, j))); 83 // showPic(pec[pec_num]); 84 } 85 86 ImgCut(){} 87 88 ImgCut(Mat spl_img) 89 { 90 img = spl_img.clone(); 91 cvtColor(img, img, CV_BGR2GRAY); 92 if (img.at
(0, 0)>255 / 2) 93 { 94 for (int i = 0; i < img.rows; i++) 95 for (int j = 0; j < img.cols; j++) 96 img.at
(i, j) = 255 - img.at
(i, j); 97 } 98 // showPic(img); 99 img_sav = img.clone();100 int sz = sqrt(img_sav.rows*img_sav.cols) / 28 / 2;101 if ((sz & 1) == 0) sz ^= 1;102 for (int i = 0; i < img_sav.rows; i++)103 for (int j = 0; j < img_sav.cols; j++)104 {105 double tmp = 0, cnt = 0;106 int pp, qq;107 for (int p = 0; p < sz; p++)108 for (int q = 0; q < sz; q++)109 {110 pp = i + p - sz / 2, qq = j + q - sz / 2;111 if (pp < 0 || qq < 0 || pp >= img_sav.rows || qq >= img_sav.cols)112 continue;113 cnt++, tmp += img.at
(pp, qq);114 }115 img.at
(i, j) = (int)(tmp / cnt + 0.5);116 }117 // img = img_sav.clone();118 // showPic(img);119 threshold(img, img, 255 / 3 , 255, THRESH_BINARY);120 Mat mat_tmp = img.clone();121 int tmp, tmpi, tmpj;122 // showPic(mat_tmp);123 for (int i = 0; i < img.rows; i++)124 for (int j = 0; j < img.cols; j++)125 {126 tmp = 0;127 for (int p = 0; p < 5; p++)128 for (int q = 0; q < 5; q++)129 {130 tmpi = i + p - 2; tmpj = j + q - 2;131 if (tmpi < 0 || tmpi >= img.rows || tmpj < 0 || tmpj >= img.cols)132 continue;133 tmp += ((int)mat_tmp.at
(tmpi, tmpj) == 255);134 }135 img.at
(i, j) = (int)((tmp>3) * 255);136 }137 // showPic(img);138 memset(tag, 0, sizeof(tag)); pec_num = 0;139 for (int i = 0; i < img.rows; i++)140 for (int j = 0; j < img.cols; j++)141 tag[j] |= ((int)img.at
(i, j) == 255);142 for (int j = 0; j < img.cols; j++)143 if (tag[j] == 1)144 {145 int tmpj = j;146 while (tmpj + 1 < img.cols && tag[tmpj] == 1)147 tmpj++;148 deal(j, tmpj), ++pec_num;149 j = tmpj;150 }151 }152 153 154 };155 156 157 #endif
ImgCut.h

 

转载于:https://www.cnblogs.com/FxxL/p/8410646.html

你可能感兴趣的文章
-UVa10935题:Trowing cards away1解答及简单分析
查看>>
Flex事件(转)
查看>>
Nodejs模块化
查看>>
一个应用商店的展示
查看>>
GPUImage的简单使用
查看>>
VIM Pal 1.1.0 发布,VIM 文件树列表
查看>>
Exam 70-762 Developing SQL Databases
查看>>
关于排列问题的一系列归类
查看>>
(转)php语法(符号用法)
查看>>
Delphi Post登陆Delphi盒子论坛源码
查看>>
wcf自定义绑定
查看>>
MongoDB数据文件内部结构(转载)
查看>>
IntelliJ IDEA常用统一设置(Linux/Mac/Windows)
查看>>
JSP发送电子邮件
查看>>
button变成href (即按钮超链效果)
查看>>
SQL语句-创建索引
查看>>
通用服务器架构
查看>>
安卓(Android)手机Flash Player官方下载地址
查看>>
laravel Faker-1.faker假数据
查看>>
python3安装Fabric模块
查看>>