• ベストアンサー

BP学習に使う学習データについて

ここ http://mars.elcom.nitech.ac.jp/java-cai/neuro/menu.html を参考にニューラルネットワークのBP学習のプログラムを作成しているのですが、 学習データについて疑問があります。 学習に使用する学習データをプログラムでは乱数を生成して作っている のですが、学習データの生成に際して、データの分散や平均の大きさを 考えるべきなのでしょうか? また考えるべきだとしたら、どのように評価したらよいのでしょうか?

質問者が選んだベストアンサー

  • ベストアンサー
  • tomo316
  • ベストアンサー率35% (51/142)
回答No.1

データの分散や平均の大きさを考えると、統計ですね。 参考まで。 /* Back-Propagation Program ver1.01 */ #include<stdio.h> #include<math.h> #include<stdlib.h> #include<time.h> #define INPUT 2 #define HIDDEN 4 #define OUTPUT 1 #define PATTERN 4 #define PR 100 #define MAX_T 10000 #define eta 2.4 #define eps 1.0e-4 #define alpha 0.8 #define beta 0.8 #define W0 0.5 double xi[INPUT+1],v[HIDDEN+1],o[OUTPUT],zeta[OUTPUT]; double w1[HIDDEN][INPUT+1],w2[OUTPUT][HIDDEN+1]; double d_w1[HIDDEN][INPUT+1],d_w2[OUTPUT][HIDDEN+1]; double pre_dw1[HIDDEN][INPUT+1],pre_dw2[OUTPUT][HIDDEN+1]; double data[PATTERN][INPUT],t_data[PATTERN][OUTPUT]; void load_data(char *filename); void back_propagation(); void w_init(); double ranran(); void dw_init(); void xi_set(long int t, int p); void forward(long int t); void backward(); double calc_error(); void modify_w(); void w_print(); double sigmoid(double u); main(int argc, char *argv[]) { load_data(*++argv); back_propagation(); w_print(); } void load_data(char *filename) { int p,k,i; double value; FILE *fp; fp = fopen(filename,"r"); if( fp == NULL ) { fprintf(stderr,"File Open Error!\n"); exit(0); } for( p=0 ; p < PATTERN ; p++ ){ for( k=0 ; k < INPUT ; k++ ){ fscanf(fp," %lf",&value); data[p][k] = value; } for( i=0 ; i < OUTPUT ; i++ ){ fscanf(fp," %lf",&value); t_data[p][i] = value; } } fclose(fp); printf("Input Desired\n"); for( p=0 ; p<PATTERN ; p++ ){ printf("{"); for( k=0 ; k<INPUT ; k++ ) printf(" %.0lf,",data[p][k]); printf("} -> {"); for( i=0 ; i<OUTPUT ; i++) printf("%.0lf,",t_data[p][i]); printf("}\n"); } putchar('\n'); } void back_propagation() { long int t; int p; double E,Esum; w_init(); for( t=0 ; t < MAX_T ; t++ ){ dw_init(); for( p=0, Esum=0 ; p < PATTERN ; p++ ){ xi_set(t,p); forward(t); backward(); Esum += calc_error(); } modify_w(); E = Esum / (OUTPUT * PATTERN); if( t%PR == 0 ) printf("%ld %e\n",t,E); if( E < eps ) break; } printf("\nTime = %ld",t); if( t == MAX_T ) printf(" (MAX) You must retry!"); putchar('\n'); for( p=0 ; p < PATTERN ; p++ ){ xi_set(0,p); forward(0); } printf("E = %e\n",E); } void w_init() { int i,j,k; long time_t; time_t = time(NULL); //srand48(time_t); srand(time_t); for( j=0 ; j < HIDDEN ; j++ ) for( k=0 ; k <INPUT+1 ; k++ ){ w1[j][k] = ranran(); d_w1[j][k] = 0.0; } for( i=0 ; i < OUTPUT ; i++ ) for( j=0 ; j < HIDDEN+1 ; j++ ){ w2[i][j] = ranran(); d_w2[i][j] = 0.0; } } double ranran() { double r; //r = drand48(); r = rand(); r = r * 2*W0 - W0; return r; } void dw_init() { int i,j,k; for( j=0 ; j < HIDDEN ; j++) for( k=0 ; k < INPUT+1 ; k++ ){ pre_dw1[j][k] = d_w1[j][k]; d_w1[j][k] = 0.0; } for( i=0 ; i <OUTPUT ; i++ ) for( j=0 ; j < HIDDEN+1 ; j++ ){ pre_dw2[i][j] = d_w2[i][j]; d_w2[i][j] = 0.0; } } void xi_set(long int t, int p) { int i,k; if( t%PR == 0 ) printf("Input "); for( k=0 ; k < INPUT ; k++ ){ xi[k] = data[p][k]; if( t%PR == 0 ) printf(" %.0lf ",xi[k]); } xi[INPUT] = 1.0; if( t%PR == 0 ) putchar('('); for( i=0 ; i < OUTPUT ; i++ ){ zeta[i] = t_data[p][i]; if( t%PR == 0 ) printf(" %.0lf ",zeta[i]); } if( t%PR == 0 ) printf(")\n"); } void forward(long int t) { int i,j,k; double sum; for( j=0 ; j < HIDDEN ; j++ ){ for( k=0, sum=0 ; k < INPUT+1 ; k++ ) sum += xi[k] * w1[j][k]; v[j] = sigmoid(sum); } if( t%PR == 0 ) printf("Output "); v[HIDDEN] = 1.0; for( i=0 ; i < OUTPUT ; i++ ){ for( j=0, sum=0 ; j < HIDDEN+1 ; j++ ) sum += v[j] * w2[i][j]; o[i] = sigmoid(sum); if( t%PR == 0 ) printf(" %.4lf",o[i]); } if(t %PR == 0 ) putchar('\n'); } void backward() { int i,j,k; double delta2[OUTPUT],delta1[HIDDEN+1],sum; for( i=0 ; i < OUTPUT ; i++ ) delta2[i] = beta * o[i] * (1-o[i]) * (zeta[i]-o[i]); for( j=0 ; j < HIDDEN ; j++){ for( i=0, sum=0 ; i < OUTPUT ; i++ ) sum += w2[i][j] * delta2[i]; delta1[j] = beta * v[j] * (1-v[j]) * sum; } for( i=0 ; i < OUTPUT ; i++ ) for( j=0 ; j < HIDDEN+1 ; j++) d_w2[i][j] += delta2[i] * v[j]; for( j=0 ; j < HIDDEN ; j++ ) for( k=0 ; k < INPUT+1 ; k++ ) d_w1[j][k] += delta1[j] * xi[k]; } double calc_error() { double E=0; int i; for( i=0 ; i < OUTPUT ; i++ ) E += (zeta[i]-o[i]) * (zeta[i]-o[i]); return E; } void modify_w() { int i,j,k; for( i=0 ; i < OUTPUT ; i++ ) for( j=0 ; j < HIDDEN+1 ; j++ ){ d_w2[i][j] = eta * d_w2[i][j] + alpha * pre_dw2[i][j]; w2[i][j] = w2[i][j] + d_w2[i][j]; } for( j=0 ; j < HIDDEN ; j++) for( k=0 ; k < INPUT+1 ; k++ ){ d_w1[j][k] = eta * d_w1[j][k] + alpha * pre_dw1[j][k]; w1[j][k] = w1[j][k] + d_w1[j][k]; } } void w_print() { int i,j,k; printf("Weight\n"); for(j =0 ; j < HIDDEN ; j++ ){ printf("w1[%d]={",j); for( k=0 ; k < INPUT ; k++ ){ if( k != 0 ) putchar(','); printf("%.6lf",w1[j][k]); } printf("} theta1[%d]=%.6lf\n",j,w1[j][k]); } for( i=0 ; i < OUTPUT ; i++ ){ printf("w2[%d]={",i); for( j=0 ; j < HIDDEN ;j ++ ){ if( j != 0 ) putchar(','); printf("%.6lf",w2[i][j]); } printf("} theta2[%d]=%.6lf\n",i,w2[i][j]); } } double sigmoid(double u) { return 1.0 / (1.0+exp(-beta*u)); }

marucha
質問者

補足

ファイルから読み込むデータはデータ作成の仕方によって良いデータ、悪いデータ というのができてしまうと思うのですが、良いデータ、悪いデータの見極め方はないでしょうか?

その他の回答 (1)

  • tomo316
  • ベストアンサー率35% (51/142)
回答No.2

このプログラムを例にすると、エラー値が出ます。 ゼロに近ければ成功です。 入力 x1 x2 xor 0 0 0 0 1 1 1 0 1 1 1 0 出力 デフォールトでは500回に1回学習の途中結果が次のように表示されます。 Input 1 1 ( 0 ) <- パターン0の入力と教師信号 Output 0.0127 <- その入力を入れたとき実際の出力 Input 1 0 ( 1 ) <- パターン1の入力と教師信号 Output 0.9863 <- その入力を入れたとき実際の出力 Input 0 1 ( 1 ) <- パターン2の入力と教師信号 Output 0.9856 <- その入力を入れたとき実際の出力 Input 0 0 ( 0 ) <- パターン3の入力と教師信号 Output 0.0134 <- その入力を入れたとき実際の出力 500 9.205321e-05 <- 現在の学習回数とエラー値 学習が終了すると次のように表示されます。 Time = 2277 <- かかった学習の回数 Input 1 1 ( 0 ) <- パターン0の入力と教師信号 Output 0.0041 <- その入力を入れたときの出力 Input 1 0 ( 1 ) <- 以下各パターンについて同じ Output 0.9958 Input 0 1 ( 1 ) Output 0.9949 Input 0 0 ( 0 ) Output 0.0044 E = 9.998749e-06 <- 最終的なエラー値(ゼロに近い値なはず) Weight <- 学習終了後の重みと閾値 w1[0]={3.914727,-3.845814} theta1[0]=-2.108753 w1[1]={3.554649,-3.407200} theta1[1]=1.706878 w2[0]={5.865372,-5.808394} theta2[0]=2.825594

関連するQ&A