With a GPU batch size of 512, processing 20 mini-batches at a time (10,240 samples), it takes less than 2 hours. I actually do 2 epochs per block of samples (not full epochs over the entire sample set at once). On my GTX 1070 each epoch over a 512-sample batch takes about 3 seconds. The vast majority of the time goes to parsing the PGN input, which is fed directly into training. Most other approaches convert the PGN to a binary input-plane format first, but I have limited disk space (yes, I know disks are relatively cheap), and I also want to be able to change the input representation without having to re-process everything. Of course, I end up reparsing it every run, but that's how I'm currently doing it.
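To make the "20 mini-batches at a time" idea concrete, here is a minimal pure-Python sketch of grouping a streamed sample source into blocks of 512 × 20 = 10,240 samples, so each block can be trained on for 2 epochs before moving on (the generator name and parameters are my own, not from any particular codebase):

```python
from itertools import islice

def super_batches(sample_stream, mini_batch=512, n_mini=20):
    """Group a stream of training samples into blocks of
    mini_batch * n_mini samples (512 * 20 = 10,240 here), so each
    block can be trained on for a couple of epochs before moving on."""
    while True:
        block = list(islice(sample_stream, mini_batch * n_mini))
        if not block:
            return
        yield block

# 25,000 samples split into two full blocks plus a 4,520-sample remainder.
blocks = list(super_batches(iter(range(25000))))
print([len(b) for b in blocks])  # [10240, 10240, 4520]
```

The same pattern works whether the stream comes from parsed PGN or from a preprocessed binary file; only the source iterator changes.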
I'm not using just MSE; the loss is a weighted combination of value-head MSE and policy-head categorical cross-entropy. After one run (again with 2 epochs per batch) the accuracy is about 70% and the loss goes down from 8.6 to 1.1. Finally, the actual numbers are less important than the trend (although I would like to get to greater accuracy).
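In Keras this combined objective is usually expressed via `loss` and `loss_weights` in `model.compile()`; as a self-contained illustration, here is a NumPy sketch of the same weighted sum (the weight values are illustrative, not the ones used above):

```python
import numpy as np

def combined_loss(value_pred, value_true, policy_logits, policy_true,
                  value_weight=1.0, policy_weight=1.0):
    """Weighted sum of value-head MSE and policy-head cross-entropy."""
    # Value head: mean squared error against the game outcome in [-1, 1].
    mse = np.mean((value_pred - value_true) ** 2)
    # Policy head: categorical cross-entropy against the move target,
    # computed from logits via a numerically stable log-softmax.
    z = policy_logits - policy_logits.max(axis=1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ce = -np.mean((policy_true * log_softmax).sum(axis=1))
    return value_weight * mse + policy_weight * ce

# Perfect value predictions + uniform policy over 4 moves -> loss = ln(4).
v = np.array([0.5, -0.5])
loss = combined_loss(v, v, np.zeros((2, 4)), np.eye(4)[:2])
print(loss)  # ~1.386
```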
I am still fiddling with the learning-rate parameters and am using a 1-cycle LR schedule (along with the 2 epochs per batch) from here:
https://medium.com/oracledevs/lessons-f ... cfcbe4ca9a
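The core of the 1-cycle schedule is simple enough to sketch directly: ramp the LR linearly up to a maximum over the first half of training, then back down over the second half (the `max_lr` and `div` values below are placeholders, not my actual settings; the full recipe in the linked article also adds a short final annealing tail, omitted here):

```python
def one_cycle_lr(step, total_steps, max_lr=0.1, div=10.0):
    """Basic triangular 1-cycle schedule: linear warm-up from
    max_lr/div to max_lr over the first half of training, then a
    linear ramp back down over the second half."""
    base = max_lr / div
    half = total_steps / 2
    if step <= half:
        return base + (max_lr - base) * (step / half)
    return max_lr - (max_lr - base) * ((step - half) / half)

# LR starts at 0.01, peaks at 0.1 mid-run, and returns to 0.01.
print(one_cycle_lr(0, 100), one_cycle_lr(50, 100), one_cycle_lr(100, 100))
```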
For more net info see:
https://github.com/Zeta36/chess-alpha-zero
FYI, the net looks like this:
Code:
>>> model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 18, 8, 8) 0
__________________________________________________________________________________________________
input_conv-3-32 (Conv2D) (None, 32, 8, 8) 5184 input_1[0][0]
__________________________________________________________________________________________________
input_batchnorm (BatchNormaliza (None, 32, 8, 8) 128 input_conv-3-32[0][0]
__________________________________________________________________________________________________
input_relu (Activation) (None, 32, 8, 8) 0 input_batchnorm[0][0]
__________________________________________________________________________________________________
res1_conv1-3-32 (Conv2D) (None, 32, 8, 8) 9216 input_relu[0][0]
__________________________________________________________________________________________________
res1_batchnorm1 (BatchNormaliza (None, 32, 8, 8) 128 res1_conv1-3-32[0][0]
__________________________________________________________________________________________________
res1_relu1 (Activation) (None, 32, 8, 8) 0 res1_batchnorm1[0][0]
__________________________________________________________________________________________________
res1_conv2-3-32 (Conv2D) (None, 32, 8, 8) 9216 res1_relu1[0][0]
__________________________________________________________________________________________________
res1_batchnorm2 (BatchNormaliza (None, 32, 8, 8) 128 res1_conv2-3-32[0][0]
__________________________________________________________________________________________________
res1_add (Add) (None, 32, 8, 8) 0 input_relu[0][0]
res1_batchnorm2[0][0]
__________________________________________________________________________________________________
res1_relu2 (Activation) (None, 32, 8, 8) 0 res1_add[0][0]
__________________________________________________________________________________________________
res2_conv1-3-32 (Conv2D) (None, 32, 8, 8) 9216 res1_relu2[0][0]
__________________________________________________________________________________________________
res2_batchnorm1 (BatchNormaliza (None, 32, 8, 8) 128 res2_conv1-3-32[0][0]
__________________________________________________________________________________________________
res2_relu1 (Activation) (None, 32, 8, 8) 0 res2_batchnorm1[0][0]
__________________________________________________________________________________________________
res2_conv2-3-32 (Conv2D) (None, 32, 8, 8) 9216 res2_relu1[0][0]
__________________________________________________________________________________________________
res2_batchnorm2 (BatchNormaliza (None, 32, 8, 8) 128 res2_conv2-3-32[0][0]
__________________________________________________________________________________________________
res2_add (Add) (None, 32, 8, 8) 0 res1_relu2[0][0]
res2_batchnorm2[0][0]
__________________________________________________________________________________________________
res2_relu2 (Activation) (None, 32, 8, 8) 0 res2_add[0][0]
__________________________________________________________________________________________________
res3_conv1-3-32 (Conv2D) (None, 32, 8, 8) 9216 res2_relu2[0][0]
__________________________________________________________________________________________________
res3_batchnorm1 (BatchNormaliza (None, 32, 8, 8) 128 res3_conv1-3-32[0][0]
__________________________________________________________________________________________________
res3_relu1 (Activation) (None, 32, 8, 8) 0 res3_batchnorm1[0][0]
__________________________________________________________________________________________________
res3_conv2-3-32 (Conv2D) (None, 32, 8, 8) 9216 res3_relu1[0][0]
__________________________________________________________________________________________________
res3_batchnorm2 (BatchNormaliza (None, 32, 8, 8) 128 res3_conv2-3-32[0][0]
__________________________________________________________________________________________________
res3_add (Add) (None, 32, 8, 8) 0 res2_relu2[0][0]
res3_batchnorm2[0][0]
__________________________________________________________________________________________________
res3_relu2 (Activation) (None, 32, 8, 8) 0 res3_add[0][0]
__________________________________________________________________________________________________
res4_conv1-3-32 (Conv2D) (None, 32, 8, 8) 9216 res3_relu2[0][0]
__________________________________________________________________________________________________
res4_batchnorm1 (BatchNormaliza (None, 32, 8, 8) 128 res4_conv1-3-32[0][0]
__________________________________________________________________________________________________
res4_relu1 (Activation) (None, 32, 8, 8) 0 res4_batchnorm1[0][0]
__________________________________________________________________________________________________
res4_conv2-3-32 (Conv2D) (None, 32, 8, 8) 9216 res4_relu1[0][0]
__________________________________________________________________________________________________
res4_batchnorm2 (BatchNormaliza (None, 32, 8, 8) 128 res4_conv2-3-32[0][0]
__________________________________________________________________________________________________
res4_add (Add) (None, 32, 8, 8) 0 res3_relu2[0][0]
res4_batchnorm2[0][0]
__________________________________________________________________________________________________
res4_relu2 (Activation) (None, 32, 8, 8) 0 res4_add[0][0]
__________________________________________________________________________________________________
value_conv-1-1 (Conv2D) (None, 4, 8, 8) 128 res4_relu2[0][0]
__________________________________________________________________________________________________
policy_conv-1-2 (Conv2D) (None, 8, 8, 8) 256 res4_relu2[0][0]
__________________________________________________________________________________________________
value_batchnorm (BatchNormaliza (None, 4, 8, 8) 16 value_conv-1-1[0][0]
__________________________________________________________________________________________________
policy_batchnorm (BatchNormaliz (None, 8, 8, 8) 32 policy_conv-1-2[0][0]
__________________________________________________________________________________________________
value_relu (Activation) (None, 4, 8, 8) 0 value_batchnorm[0][0]
__________________________________________________________________________________________________
policy_relu (Activation) (None, 8, 8, 8) 0 policy_batchnorm[0][0]
__________________________________________________________________________________________________
value_flatten (Flatten) (None, 256) 0 value_relu[0][0]
__________________________________________________________________________________________________
policy_flatten (Flatten) (None, 512) 0 policy_relu[0][0]
__________________________________________________________________________________________________
value_dense (Dense) (None, 256) 65792 value_flatten[0][0]
__________________________________________________________________________________________________
policy_out (Dense) (None, 1968) 1009584 policy_flatten[0][0]
__________________________________________________________________________________________________
value_out (Dense) (None, 1) 257 value_dense[0][0]
==================================================================================================
Total params: 1,156,129
Trainable params: 1,155,529
Non-trainable params: 600
__________________________________________________________________________________________________
>>>
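As a sanity check, the parameter counts in the summary can be reproduced by hand from the layer shapes. The counts imply the convolutions are bias-free (e.g. 5184 = 3·3·18·32 exactly), which is the usual choice when every conv is followed by batch norm:

```python
def conv(k, c_in, c_out):
    # Bias-free k x k convolution: k*k*c_in*c_out weights.
    return k * k * c_in * c_out

def bn(channels):
    # BatchNorm: gamma + beta (trainable), moving mean + variance (not).
    return 4 * channels

def dense(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias

stem = conv(3, 18, 32) + bn(32)             # 5184 + 128
res_block = 2 * (conv(3, 32, 32) + bn(32))  # 18,688 per residual block
value_head = conv(1, 32, 4) + bn(4) + dense(4 * 8 * 8, 256) + dense(256, 1)
policy_head = conv(1, 32, 8) + bn(8) + dense(8 * 8 * 8, 1968)

total = stem + 4 * res_block + value_head + policy_head
print(total)  # 1156129, matching "Total params: 1,156,129"

# Non-trainable params are the BN moving statistics: 2 per channel
# across the stem, 8 residual BNs, and the two head BNs.
non_trainable = 2 * (32 + 8 * 32 + 4 + 8)
print(non_trainable)  # 600, matching "Non-trainable params: 600"
```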