I don't know yet whether this is already faster than the version without Winograd. It still needs to be optimized/rewritten.
Code: Select all
...
// Precompute the Winograd F(2,3) filter-transform factors for slice d: for every consecutive
// group of three weights (w0, w1, w2) enumerated from Weight[d], store the pair
// ((w0 + w1 + w2) / 2, (w0 - w1 + w2) / 2). These are the two transformed filter values that
// CombinedInnerProduct multiplies by (u1 + u2) and (u2 - u1), presumably fed to it via
// winoGradBlockFactors[d] — TODO confirm at the call site.
// A trailing partial group (element count not a multiple of 3) is padded with zeros instead of
// reading Current past the end of the enumerator; this matches how CombinedInnerProduct pads
// missing taps.
var weightCount = Weight[d].nElements(); // loop-invariant: evaluate once
var winoGradSliceFactors = new List<(double, double)>();
var wEnum = Weight[d].GetEnumerator();
wEnum.Reset();
for (int i = 0; i < weightCount; i += 3)
{
    var w0 = wEnum.MoveNext() ? wEnum.Current : 0;
    var w1 = wEnum.MoveNext() ? wEnum.Current : 0;
    var w2 = wEnum.MoveNext() ? wEnum.Current : 0;
    winoGradSliceFactors.Add(((w0 + w1 + w2) / 2, (w0 - w1 + w2) / 2));
}
winoGradBlockFactors.Add(winoGradSliceFactors);
….
// Computes the inner products x * w for two consecutive windows x of one slice,
// window(l, m) and window(l, m + 1), in a single pass using Winograd's minimal filtering
// algorithm F(2,3): each group of three filter taps costs four multiplications (m1..m4)
// instead of the six needed to evaluate both windows directly.
//
// l : row number (rows outside [0, Start.Width) contribute nothing, which implements padding)
// m : column number (columns outside [0, Start.Height) are read as 0)
// d : depth, i.e. position of the slice in the block/set of slices to be evaluated
// winoGradFactors : precomputed ((w0 + w1 + w2) / 2, (w0 - w1 + w2) / 2) pairs, one per group
//                   of three taps of Weight[d], in the same order the taps are enumerated
// Returns (s1, s2): the inner products for window(l, m) and window(l, m + 1).
(double s1, double s2) CombinedInnerProduct(int d, int l, int m, List<(double, double)> winoGradFactors)
{
    Debug.Assert(m + 1 < Start.Height - F + 1 + Padding);
    // Taps are consumed from a flat enumerator in groups of three per filter row, and the
    // factor list is built in flat groups of three as well. Both stay aligned with the filter
    // rows only when F is a multiple of 3; otherwise a row's last group would borrow taps from
    // the next row.
    Debug.Assert(F % 3 == 0, "Winograd F(2,3) path requires the filter size F to be a multiple of 3.");

    var slice = Start.SliceList[d];
    var w = Weight[d];
    var wEnum = w.GetEnumerator();
    var winoGradFactorsEnum = winoGradFactors.GetEnumerator();
    wEnum.Reset(); // essential
    var height = Start.Height;
    double sum1 = 0;
    double sum2 = 0;
    for (int i = 0; i <= F - 1; i++)
    {
        int row = l + i;
        if (row >= 0 && row < Start.Width)
        {
            // The slice row does not depend on j, so fetch it once per filter row.
            var sliceRow = slice.Value[row];
            for (int j = 0; j <= F - 1; j += 3)
            {
                // Taps beyond the end of the weights are treated as 0.
                var w0 = wEnum.MoveNext() ? wEnum.Current : 0;
                var w1 = wEnum.MoveNext() ? wEnum.Current : 0;
                var w2 = wEnum.MoveNext() ? wEnum.Current : 0;
                var winoGradfactorPair = winoGradFactorsEnum.MoveNext() ? winoGradFactorsEnum.Current : (0, 0);

                int col = m + j; // first column of the 4-wide input tile u0..u3
                // Skip the tile only when none of its four columns lies inside the slice.
                if (col + 3 >= 0 && col < height)
                {
                    var u0 = (col >= 0) ? sliceRow[col] : 0;
                    var u1 = (col + 1 >= 0 && col + 1 < height) ? sliceRow[col + 1] : 0;
                    var u2 = (col + 2 >= 0 && col + 2 < height) ? sliceRow[col + 2] : 0;
                    var u3 = (col + 3 < height) ? sliceRow[col + 3] : 0;

                    // Winograd F(2,3) products.
                    var m1 = (u0 - u2) * w0;
                    var m2 = (u1 + u2) * winoGradfactorPair.Item1;
                    var m3 = (u2 - u1) * winoGradfactorPair.Item2;
                    var m4 = (u1 - u3) * w2;
                    // Debug-only cross-check against the direct 3-tap products.
                    Debug.Assert(Math.Abs(m1 + m2 + m3 - (u0 * w0 + u1 * w1 + u2 * w2)) <= 1E-7);
                    Debug.Assert(Math.Abs(m2 - m3 - m4 - (u1 * w0 + u2 * w1 + u3 * w2)) <= 1E-7);
                    sum1 += m1 + m2 + m3; // contribution to window(l, m)
                    sum2 += m2 - m3 - m4; // contribution to window(l, m + 1)
                }
            }
        }
        else
        {
            // Row is outside the slice: advance both enumerators past this row's taps
            // so the remaining rows stay aligned with their weights and factors.
            for (int j = 0; j <= F - 1; j += 3)
            {
                wEnum.MoveNext();
                wEnum.MoveNext();
                wEnum.MoveNext();
                winoGradFactorsEnum.MoveNext();
            }
        }
    }
    return (sum1, sum2);
}
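Reference: A. Lavin and S. Gray, "Fast Algorithms for Convolutional Neural Networks" (Winograd minimal filtering, F(2,3)): https://arxiv.org/pdf/1509.09308.pdf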