scikit-learnのSVMで手書きの数字画像認識する方法。正解率をあげるコツを考えてみた。

▼この記事をSNSでシェアする▼

スポンサーリンク

スポンサーリンク

1. 今回やりたいこと

こんな感じの手書きの数字を「2」だと判定させたい。これが今回のゴール。

そのために何が必要かというと、

  • scikit-learnのdatasets数字画像と、正しい数字の解答から学習データを作成する
  • そのデータをもとに手書き画像をを使い、判定する

この2点になります。

いわゆる教師あり学習ってやつです。

数字の画像と正しい数字の組み合わせ(教師)を用意した上で、そのデータに手書き画像に適用し数字を予想するというものになります。

学習させるのに最適なアルゴリズムを探すのに、役に立つチートシートを見ると、赤い四角のところにLinearSVC()(線形クラス分類)があります。

  1. start
  2. >50:yes
  3. predicting a category:yes
  4. do you have a label data:yes(数字画像に対し正解の数字がラベルされている)
  5. < 100k samples:yes

LinearSVC()は分類や回帰や異常値検出に使用されるサポートベクトルマシン(SVM)、教師あり学習方法のセットでの一つで、線形サポートベクトル分類と呼ばれる。

公式リファレンス

http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html#sklearn.svm.LinearSVC

深いところは理解してないので、勉強しなきゃ・・・

2. ぼくの利用している環境

おさらいしておきましょう。

Mac OS High Sierra10.13.6に下記のものを利用しています。

  • Atom(テキストエディタ+ターミナル付き)
  • Homebrew(パッケージマネージャ)
  • pyenv(Pythonのバージョン管理ツール)
  • pyenv-virtualenv(仮想実行環境用のツール)
  • Anaconda(Numpyとか便利なライブラリがセットで便利)
  • OpenCV:画像・動画処理のライブラリ
  • MeCab:日本語形態素解析のライブラリ
  • Gensim:自然言語処理のライブラリ
  • TensorFlow / Keras:ディープラーニングライブラリ

ここらへんを入れて、Jupyter Notebookを使っています。

機械学習のために上のセットアップをした時の手順はこちらを参照してください。

今回使うのはOpenCVです。
下記コマンドを使えばインストールできます。

3. プログラム

3.1  数字画像の学習データ生成プログラム

scikit-learnから、datasetsの数字データ・SVMのLinearSVC()学習モデルライブラリをインポートし、それをfitメソッドで学習したclfに格納されている学習モデルデータを、joblibという学習データ管理用のモジュールでdigits.pklというファイルに保存します。

ml_digits.ipynb

3.2 学習データによる数字判定プログラム

学習データ(digits.pkl)をjoblibモジュールを利用し読み込み、clfという変数に格納する。そして、手書きの数字のデータを判定しやすい白字(背景が黒色)を使うために、グレースケールにして白黒反転させます。

→datasetsにある数字の画像が白字(背景が黒色)であるため。

そのあとにclfのモデルを用いてpredictメソッドで手書きの数字を予想します。

predict_tegaki.ipynb

4. プログラムの実行結果

上のコメントに文字の太さが云々ってあったのでネタバレ感ありますが、実際にプログラムを実行すると、以下のようになりました。

accuracy_scoreを使って判定した結果:約95%の正解率となりました。

そしてdigits.pklが生成されていることがわかります。

次にml_digits.ipynbとdigits.pklと同じ階層でpredict_tegaki.ipynbを実行します。

今回はぼくがGimpを使って書いた手書きの数字を判定してみました。

2と6は正しく判定できましたが、7は正しく認識されませんでした。

  

すべてサイズを揃え、400*400pxに揃えましたが、2,6は50pxの太さの線で、7は20pxの太さの線で書きました。

5. 正しく判定するために

やはりサンプルの全体と文字バランスに近いように手書きの数字を描くのが良いと思います。

scikit-learnからダウンロードするdatasetsはこんな感じです。

8*8pxセルのデータになっていて、やや文字が太いです。

上の画像のイメージに近い数字を描くように心がけるのが良いかもしれないです。

スポンサーリンク

▼この記事をSNSでシェアする▼

フォローする

メニュー・主な記事カテゴリ

おすすめ特集!




「ゆとり鳥日記」について
ITを中心に関心の赴くままに好きなように書いていく雑記ブログ!管理人が二人います。
◆フクロウ(19卒就活生)
◆トンビ(社会人1年目SE)

詳しいプロフィール
お仕事の依頼・ご要望

ゆとり鳥日記をBTCで応援する