AForge.NETのBackPropagationLearningをマルチスレッドに書き換える

AForge.NETで遊んでいる。Back Propagation Learning自体は20年以上前からあるのだが、当時は少ないメモリをやり繰りしながら一から実装するほかなかった。お手軽にライブラリでできるって素晴らしいよね。ただ残念なことにAForge.NETのBack Propagation LearningはMulti Threadに対応していないので、Parallel.Forを使って修正してみた。
マルチスレッドに対応させるための書き換えも.Net Framework 4.xではとてもお手軽にできる。今時マルチスレッドで動作しないのってすごく罪だよね。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using AForge.Neuro;
using AForge.Neuro.Learning;

namespace RemoveRule
{
    /// <summary>
    /// Back propagation teacher whose per-neuron loops are executed with <c>Parallel.For</c>.
    /// </summary>
    /// <remarks>
    /// The class derives from <see cref="BackPropagationLearning"/> but hides (rather than
    /// overrides) its members with <c>new</c> and keeps its own working arrays.
    /// <see cref="ISupervisedLearning"/> is listed again on purpose: re-implementing the
    /// interface binds interface calls to the <c>Run</c>/<c>RunEpoch</c> methods declared here
    /// instead of the hidden base implementation. Calls made through a variable typed as
    /// <see cref="BackPropagationLearning"/> still reach the base members.
    /// </remarks>
    class MultiThreadBackPropagationLearning : BackPropagationLearning, ISupervisedLearning
    {
        // network to teach
        private readonly ActivationNetwork network;

        // learning rate
        private double learningRate = 0.1;

        // momentum
        private double momentum = 0.0;

        // neuron's errors, indexed [layer][neuron]
        private readonly double[][] neuronErrors;

        // weight's updates, indexed [layer][neuron][input]
        private readonly double[][][] weightsUpdates;

        // threshold's updates, indexed [layer][neuron]
        private readonly double[][] thresholdsUpdates;

        /// <summary>
        /// Learning rate used by this teacher. Assigned values are clamped to [0, 1].
        /// </summary>
        public new double LearningRate
        {
            get { return learningRate; }
            set
            {
                learningRate = Math.Max(0.0, Math.Min(1.0, value));
            }
        }

        /// <summary>
        /// Momentum used by this teacher. Assigned values are clamped to [0, 1].
        /// </summary>
        public new double Momentum
        {
            get { return momentum; }
            set
            {
                momentum = Math.Max(0.0, Math.Min(1.0, value));
            }
        }

        /// <summary>
        /// Creates the teacher and allocates the error/update arrays to match the network layout.
        /// </summary>
        /// <param name="network">Network to teach.</param>
        public MultiThreadBackPropagationLearning(ActivationNetwork network) : base(network)
        {
            this.network = network;

            int layersCount = network.Layers.Length;

            // create error and deltas arrays
            neuronErrors = new double[layersCount][];
            weightsUpdates = new double[layersCount][][];
            thresholdsUpdates = new double[layersCount][];

            // initialize errors and deltas arrays for each layer
            for (int i = 0; i < layersCount; i++)
            {
                Layer layer = network.Layers[i];
                int neuronsCount = layer.Neurons.Length;

                neuronErrors[i] = new double[neuronsCount];
                weightsUpdates[i] = new double[neuronsCount][];
                thresholdsUpdates[i] = new double[neuronsCount];

                // one update slot per input of every neuron
                for (int j = 0; j < neuronsCount; j++)
                {
                    weightsUpdates[i][j] = new double[layer.InputsCount];
                }
            }
        }

        /// <summary>
        /// Runs one learning iteration for a single sample.
        /// </summary>
        /// <param name="input">Input vector.</param>
        /// <param name="output">Desired output vector.</param>
        /// <returns>Half of the summed squared error of the output layer for this sample.</returns>
        public new double Run(double[] input, double[] output)
        {
            // compute the network's output
            network.Compute(input);

            // calculate network error
            double error = CalculateError(output);

            // calculate weights updates
            CalculateUpdates(input);

            // update the network
            UpdateNetwork();

            return error;
        }

        /// <summary>
        /// Runs one learning iteration for every sample, in order.
        /// </summary>
        /// <param name="input">Array of input vectors.</param>
        /// <param name="output">Array of desired output vectors; <c>output[i]</c> belongs to <c>input[i]</c>.</param>
        /// <returns>Sum of the per-sample errors returned by <see cref="Run"/>.</returns>
        public new double RunEpoch(double[][] input, double[][] output)
        {
            double error = 0.0;

            // samples are processed sequentially: each Run changes the weights the next one uses
            for (int i = 0; i < input.Length; i++)
            {
                error += Run(input[i], output[i]);
            }

            // return summary error
            return error;
        }

        /// <summary>
        /// Fills <c>neuronErrors</c> for every layer, starting at the output layer and walking backwards.
        /// </summary>
        /// <param name="desiredOutput">Desired output vector.</param>
        /// <returns>Half of the summed squared error of the output layer.</returns>
        private double CalculateError(double[] desiredOutput)
        {
            int layersCount = network.Layers.Length;

            // assume, that all neurons of the network have the same activation function
            IActivationFunction function = ((ActivationNeuron)network.Layers[0].Neurons[0]).ActivationFunction;

            // calculate error values for the last layer first
            Layer outputLayer = network.Layers[layersCount - 1];
            double[] outputErrors = neuronErrors[layersCount - 1];

            // Every iteration writes only its own slot, so no lock is needed here and the
            // squared errors can be summed afterwards in a fixed, reproducible order.
            double[] squaredErrors = new double[outputLayer.Neurons.Length];

            Parallel.For(0, outputLayer.Neurons.Length, i =>
            {
                // must be local to the lambda: a captured variable would be shared by all threads
                double output = outputLayer.Neurons[i].Output;

                // error of the neuron
                double e = desiredOutput[i] - output;

                // Calls into the activation function stay serialized because the thread
                // safety of a custom IActivationFunction is not known here.
                double derivative;
                lock (function)
                {
                    derivative = function.Derivative2(output);
                }

                // error multiplied with activation function's derivative
                outputErrors[i] = e * derivative;
                squaredErrors[i] = e * e;
            });

            double error = 0.0;
            for (int i = 0; i < squaredErrors.Length; i++)
            {
                error += squaredErrors[i];
            }

            // calculate error values for other layers
            for (int j = layersCount - 2; j >= 0; j--)
            {
                // per-iteration locals, so each lambda captures the values of its own layer
                Layer layer = network.Layers[j];
                Layer layerNext = network.Layers[j + 1];
                double[] errors = neuronErrors[j];
                double[] errorsNext = neuronErrors[j + 1];

                // for all neurons of the layer
                Parallel.For(0, layer.Neurons.Length, i =>
                {
                    double sum = 0.0;

                    // for all neurons of the next layer
                    for (int k = 0; k < layerNext.Neurons.Length; k++)
                    {
                        sum += errorsNext[k] * layerNext.Neurons[k].Weights[i];
                    }

                    double derivative;
                    lock (function)
                    {
                        derivative = function.Derivative2(layer.Neurons[i].Output);
                    }

                    errors[i] = sum * derivative;
                });
            }

            // return squared error of the last layer divided by 2
            return error / 2.0;
        }

        /// <summary>
        /// Computes the weight and threshold updates of every layer from <c>neuronErrors</c>.
        /// </summary>
        /// <param name="input">Input vector that was fed to the network.</param>
        private void CalculateUpdates(double[] input)
        {
            // cache for frequently used values
            double cachedMomentum = learningRate * momentum;
            double cached1mMomentum = learningRate * (1 - momentum);

            // 1 - calculate updates for the first layer (its inputs are the sample itself)
            Layer firstLayer = network.Layers[0];
            double[] firstErrors = neuronErrors[0];
            double[][] firstWeightsUpdates = weightsUpdates[0];
            double[] firstThresholdUpdates = thresholdsUpdates[0];

            // for each neuron of the layer
            Parallel.For(0, firstLayer.Neurons.Length, i =>
            {
                double cachedError = firstErrors[i] * cached1mMomentum;
                double[] neuronWeightUpdates = firstWeightsUpdates[i];

                // for each weight of the neuron
                for (int j = 0; j < neuronWeightUpdates.Length; j++)
                {
                    // calculate weight update
                    neuronWeightUpdates[j] = cachedMomentum * neuronWeightUpdates[j] + cachedError * input[j];
                }

                // calculate threshold update
                firstThresholdUpdates[i] = cachedMomentum * firstThresholdUpdates[i] + cachedError;
            });

            // 2 - for all other layers (their inputs are the previous layer's outputs)
            for (int k = 1; k < network.Layers.Length; k++)
            {
                Layer layerPrev = network.Layers[k - 1];
                Layer layer = network.Layers[k];
                double[] errors = neuronErrors[k];
                double[][] layerWeightsUpdates = weightsUpdates[k];
                double[] layerThresholdUpdates = thresholdsUpdates[k];

                // for each neuron of the layer
                Parallel.For(0, layer.Neurons.Length, i =>
                {
                    double cachedError = errors[i] * cached1mMomentum;
                    double[] neuronWeightUpdates = layerWeightsUpdates[i];

                    // for each synapse of the neuron
                    for (int j = 0; j < neuronWeightUpdates.Length; j++)
                    {
                        // calculate weight update
                        neuronWeightUpdates[j] = cachedMomentum * neuronWeightUpdates[j] + cachedError * layerPrev.Neurons[j].Output;
                    }

                    // calculate threshold update
                    layerThresholdUpdates[i] = cachedMomentum * layerThresholdUpdates[i] + cachedError;
                });
            }
        }

        /// <summary>
        /// Adds the previously computed updates to the weights and thresholds of every neuron.
        /// </summary>
        private void UpdateNetwork()
        {
            // for each layer of the network
            for (int i = 0; i < network.Layers.Length; i++)
            {
                Layer layer = network.Layers[i];
                double[][] layerWeightsUpdates = weightsUpdates[i];
                double[] layerThresholdUpdates = thresholdsUpdates[i];

                // for each neuron of the layer; every iteration touches only its own neuron
                Parallel.For(0, layer.Neurons.Length, j =>
                {
                    ActivationNeuron neuron = (ActivationNeuron)layer.Neurons[j];
                    double[] neuronWeightUpdates = layerWeightsUpdates[j];

                    // for each weight of the neuron
                    for (int k = 0; k < neuron.Weights.Length; k++)
                    {
                        // update weight
                        neuron.Weights[k] += neuronWeightUpdates[k];
                    }

                    // update threshold
                    neuron.Threshold += layerThresholdUpdates[j];
                });
            }
        }
    }
}

元々のBackPropagationLearningクラスでメンバ変数がprivate宣言されていたのがかなり残念。派生したクラスで挙動をオーバーライド出来ないので、丸ごと同じクラスを実装する他に無いのだ。こうなると、そもそもBackPropagationLearningを継承する必要すらなかったかもしれない。

Download MultiThreadBackPropagationLearning.zip

「AForge.NETのBackPropagationLearningをマルチスレッドに書き換える」への1件のフィードバック

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です