ひきぷろのプログラミング日記

プログラミングの日記です。

(C#) MNISTの画像データを読み込むクラス

ニューラルネットワークのテストをしようと思って、MNISTの画像データを読み込むクラスを C# で作ってみました。

作ったものをコピペします。

ビッグエンディアンで数値が扱える BinaryReader

MNIST の画像データは、ビッグエンディアン形式で保存されています。
普通の BinaryReader はリトルエンディアンしか扱えないので、ビッグエンディアンを扱えるようにしました。

using System;
using System.Linq;
using System.Text;
using System.IO;

namespace NNTest {
	class BinaryReaderBE : BinaryReader {
		public BinaryReaderBE(Stream input)
			: base(input) {
		}
		public BinaryReaderBE(Stream input, Encoding encoding)
			: base(input, encoding) {
		}

		public override short ReadInt16() {
			return _ToBigEndian(base.ReadInt16());
		}
		public override int ReadInt32() {
			return _ToBigEndian(base.ReadInt32());
		}
		public override long ReadInt64() {
			return _ToBigEndian(base.ReadInt64());
		}
		public override ushort ReadUInt16() {
			return _ToBigEndian(base.ReadUInt16());
		}
		public override uint ReadUInt32() {
			return _ToBigEndian(base.ReadUInt32());
		}
		public override ulong ReadUInt64() {
			return _ToBigEndian(base.ReadUInt64());
		}
		public override float ReadSingle() {
			return _ToBigEndian(base.ReadSingle());
		}
		public override double ReadDouble() {
			return _ToBigEndian(base.ReadDouble());
		}
		public override decimal ReadDecimal() {
			throw new NotImplementedException();
		}

		private short _ToBigEndian(short value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToInt16(bytes, 0);
		}

		private ushort _ToBigEndian(ushort value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToUInt16(bytes, 0);
		}

		private int _ToBigEndian(int value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToInt32(bytes, 0);
		}

		private uint _ToBigEndian(uint value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToUInt32(bytes, 0);
		}

		private long _ToBigEndian(long value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToInt64(bytes, 0);
		}

		private ulong _ToBigEndian(ulong value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToUInt64(bytes, 0);
		}

		private float _ToBigEndian(float value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToSingle(bytes, 0);
		}

		private double _ToBigEndian(double value) {
			byte[] bytes = BitConverter.GetBytes(value);
			bytes = _ReverseBytes(bytes);
			return BitConverter.ToDouble(bytes, 0);
		}

		private byte[] _ReverseBytes(byte[] bytes) {
			if (bytes == null) {
				return null;
			}
			return bytes.Reverse().ToArray();
		}
	}
}
画像データのローダ

上の、 BinaryReaderBE を使って画像データを読み込むためのクラスです。
見たまんまだと思うので、すぐに使えると思います。

using System;
using System.IO;
using System.Drawing;
using System.Collections.Generic;
using System.Drawing.Imaging;
using System.Runtime.InteropServices;
using System.Text;

namespace NNTest {
	/// <summary>
	/// MNIST の画像をロードするためのクラス.
	/// http://yann.lecun.com/exdb/mnist/
	/// </summary>
	class MNistImageLoader {
		/// <summary>
		/// 0x0000 から始まるマジックナンバー.
		/// 0x00000803 (2051) が入る.
		/// </summary>
		public int magicNumber;

		/// <summary>
		/// 画像の数.
		/// </summary>
		public int numberOfImages;

		/// <summary>
		/// 画像の縦方向のサイズ.
		/// </summary>
		public int numberOfRows;

		/// <summary>
		/// 画像の横方向のサイズ.
		/// </summary>
		public int numberOfColumns;

		/// <summary>
		/// 画像の配列.
		/// Bitmap 形式で取得する場合は GetBitmap(index) を使用する.
		/// </summary>
		public List<byte[]> bitmapList;

		/// <summary>
		/// コンストラクタ.
		/// </summary>
		public MNistImageLoader() {
			bitmapList = new List<byte[]>();
		}

		/// <summary>
		/// MNIST のデータをロードする.
		/// 失敗した時は null を返す.
		/// </summary>
		/// <param name="path">画像データのパス.</param>
		/// <returns></returns>
		public static MNistImageLoader Load(string path) {
			// ファイルが存在しない
			if (File.Exists(path) == false) {
				return null;
			}

			// バイト配列を分解する
			FileStream stream = new FileStream(path, FileMode.Open);
			BinaryReaderBE reader = new BinaryReaderBE(stream);

			MNistImageLoader loader = new MNistImageLoader();
			loader.magicNumber = reader.ReadInt32();
			loader.numberOfImages = reader.ReadInt32();
			loader.numberOfRows = reader.ReadInt32();
			loader.numberOfColumns = reader.ReadInt32();

			int pixelCount = loader.numberOfRows * loader.numberOfColumns;
			for (int i = 0; i < loader.numberOfImages; i++) {
				byte[] pixels = reader.ReadBytes(pixelCount);
				loader.bitmapList.Add(pixels);
			}

			reader.Close();
			return loader;
		}

		/// <summary>
		/// 引数で指定されたインデックス番号の画像を Bitmap 形式で取得する.
		/// 失敗した場合は null を返す.
		/// </summary>
		/// <param name="index">画像のインデックス番号.</param>
		/// <returns></returns>
		public Bitmap GetBitmap(int index) {
			// 範囲チェック
			if (index < 0 || index >= bitmapList.Count) {
				return null;
			}

			// Bitmap 画像を作成する
			Bitmap bitmap = new Bitmap(
				numberOfColumns,
				numberOfRows,
				PixelFormat.Format24bppRgb
			);
			BitmapData bitmapData = bitmap.LockBits(
				new Rectangle(0, 0, bitmap.Width, bitmap.Height),
				ImageLockMode.ReadWrite,
				bitmap.PixelFormat
			);


			byte[] pixels = bitmapList[index];
			IntPtr intPtr = bitmapData.Scan0;
			for (int y = 0; y < numberOfRows; y++) {
				int offsetY = bitmapData.Stride * y;
				for (int x = 0; x < numberOfColumns; x++) {
					byte b = pixels[x + y * numberOfColumns];
					// 次の行をコメントアウトすると白黒反転します
					b = (byte)~b;
					int offset = x * 3 + offsetY;
					Marshal.WriteByte(intPtr, offset + 0, b);
					Marshal.WriteByte(intPtr, offset + 1, b);
					Marshal.WriteByte(intPtr, offset + 2, b);
				}
			}

			bitmap.UnlockBits(bitmapData);
			return bitmap;
		}

		/// <summary>
		/// デバッグ用.
		/// </summary>
		/// <returns></returns>
		public override string ToString() {
			StringBuilder stringBuilder = new StringBuilder();
			stringBuilder.Append(GetType().Name);
			stringBuilder.AppendLine();
			stringBuilder.AppendFormat("\tmagicNumber: 0x{0:X8}", magicNumber);
			stringBuilder.AppendLine();
			stringBuilder.AppendFormat("\tnumberOfImages: {0}", numberOfImages);
			stringBuilder.AppendLine();
			stringBuilder.AppendFormat("\tnumberOfRows: {0}", numberOfRows);
			stringBuilder.AppendLine();
			stringBuilder.AppendFormat("\tnumberOfColumns: {0}", numberOfColumns);
			return stringBuilder.ToString();
		}
	}
}
使い方
string path = "../../data/train-images.idx3-ubyte";
MNistImageLoader loader = MNistImageLoader.Load(path);
Console.WriteLine("loader: {0}", loader.ToString());

// ビットマップ形式で画像を取得する
Bitmap bitmap = loader.GetBitmap(0);

// バイト配列で画像を取得する
byte[] bytes = loader.bitmapList[0];