/// <summary>
/// Demo entry point: multiplies a 3x5 matrix by a 5x3 matrix twice — once with
/// MatrixMultiplier.MultiplyBasic and once through the Matrix * operator — and writes
/// both results to the console so the two algorithms can be compared.
/// </summary>
public static void Main()
// 3 rows x 5 columns.
var arr1 = new float[,] {{1,2,3,4,5},{2,3,4,5,6},{3,4,5,6,7}};
var matrix1 = new Matrix(arr1);
// 5 rows x 3 columns, so the product matrix1 x matrix2 is 3x3.
var arr2 = new float[,] {{2,3,4},{3,4,5},{4,5,6},{5,6,7},{6,7,8}};
var matrix2 = new Matrix(arr2);
Console.WriteLine(MatrixMultiplier.MultiplyBasic(matrix1,matrix2));
// Same product via the * operator, which forwards to MatrixMultiplier.MultiplyStrassen.
Console.WriteLine(matrix1*matrix2);
/// <summary>Backing element storage; the first index is the row, the second the column.</summary>
private float[,] _matrix;
// NOTE(review): no assignment of _n/_m is visible in this constructor, while ToString
// iterates over _n/_m. Confirm they are set for matrices created this way (every
// operator and multiplier builds its result through this constructor).
/// <summary>Creates an <paramref name="n"/> x <paramref name="m"/> matrix whose elements all start at 0.</summary>
/// <param name="n">Number of rows.</param>
/// <param name="m">Number of columns.</param>
public Matrix(int n, int m)
this._matrix = new float[n,m];
// NOTE(review): the statement storing `matrix` into `_matrix` is not visible in this view.
// `_matrix` has no field initializer, so if that store is truly missing the two
// GetLength calls below throw NullReferenceException. Also confirm whether the array is
// copied or aliased (aliasing means later writes to the caller's array show through).
/// <summary>Creates a matrix from an existing two-dimensional array.</summary>
/// <param name="matrix">Source elements, indexed [row, column].</param>
public Matrix(float[,] matrix)
// Cache the row and column counts.
_n = _matrix.GetLength(0);
_m = _matrix.GetLength(1);
// NOTE(review): accessor bodies are not visible here; presumably they read/write _matrix[n, m].
/// <summary>Gets or sets the element at row <paramref name="n"/>, column <paramref name="m"/>.</summary>
public float this[int n, int m]
// NOTE(review): the closing `return result;` is not visible in this view.
/// <summary>Element-wise sum of two matrices of the same shape, computed into a new matrix.</summary>
/// <param name="mat1">Left operand (only read).</param>
/// <param name="mat2">Right operand (only read); must match <paramref name="mat1"/> in rows and columns.</param>
/// <exception cref="ArgumentException">Thrown when the row counts or the column counts differ.</exception>
public static Matrix operator +(Matrix mat1, Matrix mat2)
// A mismatch in EITHER dimension is an error. Joining the two tests with && only
// rejected operands that differed in both, so e.g. 2x3 + 2x4 silently ignored a column
// of mat2 and 2x4 + 2x3 indexed past mat2's last column inside the loop.
if (mat1.n != mat2.n || mat1.m != mat2.m)
throw new ArgumentException("Matrix dimensions must match for addition.");
var result = new Matrix(mat1.n, mat1.m);
for (int i = 0; i < mat1.n; i++)
for (int j = 0; j < mat1.m; j++)
result[i,j] = mat1[i,j]+mat2[i,j];
// NOTE(review): the closing `return result;` is not visible in this view.
/// <summary>Element-wise difference <paramref name="mat1"/> - <paramref name="mat2"/>, computed into a new matrix.</summary>
/// <param name="mat1">Minuend (only read).</param>
/// <param name="mat2">Subtrahend (only read); must match <paramref name="mat1"/> in rows and columns.</param>
/// <exception cref="ArgumentException">Thrown when the row counts or the column counts differ.</exception>
public static Matrix operator -(Matrix mat1, Matrix mat2)
// A mismatch in EITHER dimension is an error. Joining the two tests with && only
// rejected operands that differed in both, so a single-dimension mismatch either
// silently ignored part of mat2 or indexed past its bounds inside the loop.
if (mat1.n != mat2.n || mat1.m != mat2.m)
throw new ArgumentException("Matrix dimensions must match for subtraction.");
var result = new Matrix(mat1.n, mat1.m);
for (int i = 0; i < mat1.n; i++)
for (int j = 0; j < mat1.m; j++)
result[i,j] = mat1[i,j]-mat2[i,j];
// NOTE(review): the closing `return result;` is not visible in this view.
/// <summary>
/// Scales every element of <paramref name="mat"/> by <paramref name="num"/>, writing the
/// products into a new matrix of the same shape; <paramref name="mat"/> is only read.
/// </summary>
/// <param name="mat">Matrix to scale.</param>
/// <param name="num">Scalar factor.</param>
public static Matrix operator *(Matrix mat, float num)
var result = new Matrix(mat.n, mat.m);
for (int i = 0; i < mat.n; i++)
for (int j = 0; j < mat.m; j++)
result[i,j] = mat[i,j]*num;
/// <summary>
/// Matrix product <paramref name="mat1"/> x <paramref name="mat2"/>, delegated to the
/// Strassen implementation in MatrixMultiplier.MultiplyStrassen.
/// </summary>
public static Matrix operator *(Matrix mat1, Matrix mat2)
    => MatrixMultiplier.MultiplyStrassen(mat1, mat2);
// NOTE(review): only part of this method is visible. The statement that appends each
// element and the final `return sb.ToString()` are not shown. With the braces missing
// it is also not visible whether the NewLine append sits inside the inner loop or after
// it (one output line per row is presumably intended).
/// <summary>Formats the matrix as text, walking it row by row over _n rows and _m columns.</summary>
public override string ToString()
var sb = new StringBuilder();
for (int i = 0; i < _n; i++)
for (int j = 0; j < _m; j++)
sb.Append(Environment.NewLine);
/// <summary>
/// Matrix multiplication routines: a direct triple-loop product (MultiplyBasic) and
/// Strassen's divide-and-conquer product (MultiplyStrassen), plus private helpers for
/// padding, splitting, recombining and copying sub-matrices.
/// </summary>
internal static class MatrixMultiplier
// NOTE(review): several statements are not visible in this view: the `if` guarding the
// throw (presumably mat1.m != mat2.n), the declaration/reset of `sum` before the
// innermost loop, the store of `sum` into result[n, m], and the return. `sum` must be
// reset for every (n, m) cell, or entries accumulate across cells.
/// <summary>
/// Multiplies <paramref name="mat1"/> by <paramref name="mat2"/> with the textbook triple
/// loop: each result cell is the dot product of a row of mat1 with a column of mat2.
/// </summary>
/// <exception cref="ArgumentException">Raised by the dimension guard (its condition is not visible here).</exception>
public static Matrix MultiplyBasic(Matrix mat1, Matrix mat2)
throw new ArgumentException();
// The product has mat1's row count and mat2's column count.
var result = new Matrix(mat1.n, mat2.m);
for (int n = 0; n < result.n; n++)
for (int m = 0; m < result.m; m++)
// Dot product over the shared dimension (mat1's columns / mat2's rows).
for (int i = 0; i < mat1.m; i++)
sum += mat1[n,i]*mat2[i,m];
// NOTE(review): the `if` guarding the throw below (presumably mat1.m != mat2.n) and the
// closing `return result;` are not visible in this view.
/// <summary>
/// Multiplies <paramref name="mat1"/> (n x k) by <paramref name="mat2"/> (k x p) using
/// Strassen's algorithm and returns the n x p product.
/// </summary>
/// <remarks>
/// Both operands are zero-padded into square matrices of ONE common size (a power of two
/// covering every dimension), so each recursive split yields conformable quadrants. The
/// top-left n x p corner of the padded product is the real result.
/// </remarks>
/// <exception cref="ArgumentException">Raised by the operand guard (its condition is not visible here).</exception>
public static Matrix MultiplyStrassen(Matrix mat1, Matrix mat2)
throw new ArgumentException();
// Pad both operands to the SAME size. Padding each one on its own can yield squares of
// different sizes (e.g. 2x2 and 8x8 for a 2x2 * 2x8 product), which the recursion
// cannot multiply. Zero padding does not change the top-left block of the product.
var size = FindPowerOfTwo(System.Math.Max(System.Math.Max(mat1.n, mat1.m), System.Math.Max(mat2.n, mat2.m)));
var emat1 = new Matrix(size, size);
var emat2 = new Matrix(size, size);
CopySubmatrix(mat1, emat1, 0, 0, mat1.n-1, mat1.m-1);
CopySubmatrix(mat2, emat2, 0, 0, mat2.n-1, mat2.m-1);
var product = MultiplyStrassenRecursive(emat1, emat2);
// An (n x k) * (k x p) product is n x p; sizing it n x n was wrong whenever n != p.
var result = new Matrix(mat1.n, mat2.m);
CopySubmatrix(product, result, 0, 0, mat1.n-1, mat2.m-1);
/// <summary>
/// One recursive Strassen step: multiplies two equally sized square matrices using seven
/// half-size products (p1..p7) instead of the eight a block-wise product would need.
/// </summary>
/// <remarks>
/// Operands are expected to be square, of the same size, with a power-of-two side, so that
/// every SplitMatrix call yields four equal quadrants.
/// </remarks>
private static Matrix MultiplyStrassenRecursive(Matrix mat1, Matrix mat2) {
// Base case: 2x2 or smaller is multiplied directly. `<=` (not `==`) is required so a 1x1
// operand (e.g. from multiplying 1x1 matrices) stops here instead of reaching
// SplitMatrix, which throws ArgumentException on odd dimensions.
if (mat1.n <= 2 || mat1.m <= 2 || mat2.n <= 2 || mat2.m <= 2) {
return MultiplyBasic(mat1, mat2);
Matrix a11, a12, a21, a22;
Matrix b11, b12, b21, b22;
SplitMatrix(mat1, out a11, out a12, out a21, out a22);
SplitMatrix(mat2, out b11, out b12, out b21, out b22);
// Strassen's seven products.
var p1 = MultiplyStrassenRecursive(a11+a22, b11+b22);
var p2 = MultiplyStrassenRecursive(a21+a22, b11);
var p3 = MultiplyStrassenRecursive(a11, b12-b22);
var p4 = MultiplyStrassenRecursive(a22, b21-b11);
var p5 = MultiplyStrassenRecursive(a11+a12, b22);
var p6 = MultiplyStrassenRecursive(a21-a11, b11+b12);
var p7 = MultiplyStrassenRecursive(a12-a22, b21+b22);
// NOTE(review): c11..c22 and the final `return c;` are on lines not visible in this view.
// For the p1..p7 defined above the quadrants must be:
//   c11 = p1 + p4 - p5 + p7      c12 = p3 + p5
//   c21 = p2 + p4                c22 = p1 - p2 + p3 + p6
var c = CombineMatrix(c11, c12, c21, c22);
// NOTE(review): `max` is declared on a line not visible here (presumably Math.Max(n, m)).
// The element copy inside the loops and the return are also not shown. Each call pads one
// matrix on its own, so two operands of different shapes can come back as squares of
// different sizes.
/// <summary>
/// Zero-pads <paramref name="mat"/> into a new square max x max matrix, where the side is
/// derived from FindPowerOfTwo of each dimension; the original elements occupy the
/// top-left mat.n x mat.m corner.
/// </summary>
private static Matrix ExplodeMatrix(Matrix mat)
var n = FindPowerOfTwo(mat.n);
var m = FindPowerOfTwo(mat.m);
var result = new Matrix(max,max);
for (int i = 0; i < mat.n; i++)
for (int j = 0; j < mat.m; j++)
// NOTE(review): the parameter list continues on lines not visible here. From the body and
// the call sites it is presumably (Matrix mat, out Matrix submat11, out Matrix submat12,
// out Matrix submat21, out Matrix submat22).
/// <summary>
/// Splits mat into four equal, newly allocated quadrants: submat11 top-left, submat12
/// top-right, submat21 bottom-left, submat22 bottom-right.
/// </summary>
/// <exception cref="ArgumentException">Thrown when either dimension of mat is odd.</exception>
private static void SplitMatrix(
// Quadrants must be equal-sized, so both dimensions have to be even.
if (mat.n%2 == 1 || mat.m%2 == 1)
throw new ArgumentException();
submat11 = new Matrix(mat.n/2, mat.m/2);
submat12 = new Matrix(mat.n/2, mat.m/2);
submat21 = new Matrix(mat.n/2, mat.m/2);
submat22 = new Matrix(mat.n/2, mat.m/2);
// CopySubmatrix bounds are inclusive, hence the -1 on each end index.
CopySubmatrix(mat, submat11, 0, 0, mat.n/2-1, mat.m/2-1);
CopySubmatrix(mat, submat12, 0, mat.m/2, mat.n/2-1, mat.m-1);
CopySubmatrix(mat, submat21, mat.n/2, 0, mat.n-1, mat.m/2-1);
CopySubmatrix(mat,submat22, mat.n/2, mat.m/2, mat.n-1, mat.m-1);
// NOTE(review): `n` and `m` are declared on lines not visible here. Presumably they are
// a11's row and column counts, with all four quadrants assumed to share that shape. The
// return is also not visible.
/// <summary>
/// Inverse of SplitMatrix: assembles a new (2n) x (2m) matrix from four n x m quadrants,
/// with a11 top-left, a12 top-right, a21 bottom-left and a22 bottom-right.
/// </summary>
private static Matrix CombineMatrix(Matrix a11, Matrix a12, Matrix a21, Matrix a22)
var result = new Matrix(n*2, m*2);
for (int i = 0; i < 2*n; i++)
// Left half: top rows come from a11, bottom rows from a21.
for (int j = 0; j < m; j++)
result[i,j] = i < n ? a11[i,j] : a21[i-n,j];
// Right half: top rows come from a12, bottom rows from a22.
for (int j = m; j < 2*m; j++)
result[i,j] = i < n ? a12[i,j-m] : a22[i-n,j-m];
// NOTE(review): the parameter list continues on lines not visible here. From the body and
// the call sites it is presumably (Matrix source, Matrix target, int startN, int startM,
// int endN, int endM).
/// <summary>
/// Copies the rectangle source[startN..endN, startM..endM] (both bounds inclusive) into
/// target, so that source[startN, startM] lands at target[0, 0]. Target cells outside the
/// copied rectangle are left unchanged.
/// </summary>
private static void CopySubmatrix(
for (int n = startN; n <= endN; n++)
for (int m = startM; m <= endM; m++)
target[n-startN, m-startM] = source[n,m];
// NOTE(review): only two statements of this method are visible. A `currentnum` value is
// halved (presumably in a loop, e.g. to count significant bits). If `num` equals a
// single-bit mask 1<<i, i.e. num is already a power of two, num is returned unchanged.
// The remaining logic (presumably rounding num up to the next power of two) and the loop
// headers are not shown. Also check num == 0: (0 & x) == 0 holds for every mask.
/// <summary>
/// Maps <paramref name="num"/> to a power of two (presumably the smallest one that is
/// &gt;= num); values that are already powers of two are returned unchanged.
/// </summary>
private static int FindPowerOfTwo(int num)
currentnum = currentnum/2;
if ((num&(1<<i)) == num) return num;