/// <summary>
/// A single LSTM cell with one input feature and one hidden unit. Every weight and bias is
/// initialised to the same constant (0.5), so the cell's output is fully deterministic.
/// </summary>
public class LSTMCell
{
    private const int InputSize = 1;
    private const int HiddenSize = 1;

    // Value written into every weight and bias by the constructor.
    private const double InitialParameterValue = 0.5;

    // Gate weights, shaped (HiddenSize + InputSize) x HiddenSize. Row order follows the
    // concatenated vector built in Step: hidden-state entries first, then input entries.
    private readonly double[,] _wi; // input gate
    private readonly double[,] _wf; // forget gate
    private readonly double[,] _wc; // candidate cell state
    private readonly double[,] _wo; // output gate

    // Gate biases, one entry per hidden unit.
    private readonly double[] _bi;
    private readonly double[] _bf;
    private readonly double[] _bc;
    private readonly double[] _bo;

    /// <summary>Creates a cell whose weights and biases all equal 0.5.</summary>
    public LSTMCell()
    {
        const int concatSize = HiddenSize + InputSize;

        _wi = InitializeMatrix(concatSize, HiddenSize, InitialParameterValue);
        _wf = InitializeMatrix(concatSize, HiddenSize, InitialParameterValue);
        _wc = InitializeMatrix(concatSize, HiddenSize, InitialParameterValue);
        _wo = InitializeMatrix(concatSize, HiddenSize, InitialParameterValue);

        _bi = InitializeArray(HiddenSize, InitialParameterValue);
        _bf = InitializeArray(HiddenSize, InitialParameterValue);
        _bc = InitializeArray(HiddenSize, InitialParameterValue);
        _bo = InitializeArray(HiddenSize, InitialParameterValue);
    }

    /// <summary>
    /// Runs one time step of the cell and returns only the new hidden state.
    /// </summary>
    /// <param name="input">The input value x_t.</param>
    /// <param name="previousHiddenState">The hidden state h_{t-1} from the previous step.</param>
    /// <param name="previousCellState">The cell state c_{t-1} from the previous step.</param>
    /// <returns>
    /// The new hidden state h_t = o_t * tanh(c_t). NaN if any argument is NaN.
    /// Use <see cref="Step"/> to also obtain c_t.
    /// </returns>
    public double Process(double input, double previousHiddenState, double previousCellState)
    {
        return Step(input, previousHiddenState, previousCellState).HiddenState;
    }

    /// <summary>
    /// Runs one time step of the cell and returns both updated states, so the caller can feed
    /// them back in as the previous states of the next step.
    /// </summary>
    /// <param name="input">The input value x_t.</param>
    /// <param name="previousHiddenState">The hidden state h_{t-1} from the previous step.</param>
    /// <param name="previousCellState">The cell state c_{t-1} from the previous step.</param>
    /// <returns>The new hidden state h_t and the new cell state c_t.</returns>
    public (double HiddenState, double CellState) Step(double input, double previousHiddenState, double previousCellState)
    {
        // Build [h_{t-1}, x_t]. The weight matrices' row order depends on this layout.
        // The scalar API relies on HiddenSize == InputSize == 1.
        double[] concat = new double[HiddenSize + InputSize];
        concat[0] = previousHiddenState;
        concat[HiddenSize] = input;

        double inputGate = Sigmoid(MatrixVectorProduct(_wi, concat)[0] + _bi[0]);
        double forgetGate = Sigmoid(MatrixVectorProduct(_wf, concat)[0] + _bf[0]);
        double candidate = Tanh(MatrixVectorProduct(_wc, concat)[0] + _bc[0]);

        // Keep part of the old cell state (forget gate) and add part of the candidate (input gate).
        double cellState = forgetGate * previousCellState + inputGate * candidate;

        double outputGate = Sigmoid(MatrixVectorProduct(_wo, concat)[0] + _bo[0]);
        double hiddenState = outputGate * Tanh(cellState);

        return (hiddenState, cellState);
    }

    /// <summary>Allocates a rows x cols matrix with every element set to <paramref name="value"/>.</summary>
    private static double[,] InitializeMatrix(int rows, int cols, double value)
    {
        double[,] matrix = new double[rows, cols];
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                matrix[i, j] = value;
            }
        }
        return matrix;
    }

    /// <summary>Allocates an array of <paramref name="size"/> elements, each set to <paramref name="value"/>.</summary>
    private static double[] InitializeArray(int size, double value)
    {
        double[] array = new double[size];
        for (int i = 0; i < size; i++)
        {
            array[i] = value;
        }
        return array;
    }

    /// <summary>Logistic sigmoid 1 / (1 + e^-x). The result lies in [0, 1].</summary>
    private static double Sigmoid(double x)
    {
        return 1.0 / (1.0 + Math.Exp(-x));
    }

    /// <summary>Hyperbolic tangent. The result lies in [-1, 1].</summary>
    private static double Tanh(double x)
    {
        return Math.Tanh(x);
    }

    /// <summary>
    /// Multiplies the transpose of <paramref name="matrix"/> by <paramref name="vector"/>:
    /// result[j] = sum over i of matrix[i, j] * vector[i].
    /// </summary>
    /// <returns>An array with one element per matrix column.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="vector"/>'s length differs from the matrix's row count.
    /// </exception>
    private static double[] MatrixVectorProduct(double[,] matrix, double[] vector)
    {
        int rows = matrix.GetLength(0);
        int cols = matrix.GetLength(1);
        if (vector.Length != rows)
        {
            throw new ArgumentException("Matrix and vector dimensions are incompatible.", nameof(vector));
        }

        double[] result = new double[cols];
        for (int j = 0; j < cols; j++)
        {
            double sum = 0.0;
            for (int i = 0; i < rows; i++)
            {
                sum += matrix[i, j] * vector[i];
            }
            result[j] = sum;
        }
        return result;
    }
}
// NOTE(review): no using directives are visible in this file; NUnit.Framework must be in
// scope elsewhere (for example via a global using) for these attributes and asserts.
/// <summary>
/// NUnit tests for <see cref="LSTMCell"/>.
/// </summary>
[TestFixture]
public class LSTMCellTests
{
    // h_t = o_t * tanh(c_t): o_t is a sigmoid output in [0, 1] and tanh lies in [-1, 1],
    // so any non-NaN hidden state must fall inside [-1, 1].
    private static void AssertValidHiddenState(double hiddenState)
    {
        Assert.That(hiddenState, Is.Not.NaN);
        Assert.That(hiddenState, Is.InRange(-1.0, 1.0));
    }

    [Test]
    public void Process_PositiveInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = 0.5;
        double previousHiddenState = 0.1;
        double previousCellState = 0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_ZeroInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = 0.0;
        double previousHiddenState = 0.0;
        double previousCellState = 0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_LargeInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = 10.0;
        double previousHiddenState = 10.0;
        double previousCellState = 1.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_NegativeInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = -0.5;
        double previousHiddenState = -0.1;
        double previousCellState = -0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_NaNInput_ReturnsNaN()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = double.NaN;
        double previousHiddenState = 0.1;
        double previousCellState = 0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        Assert.That(newHiddenState, Is.NaN);
    }

    [Test]
    public void Process_MaxValueInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = double.MaxValue;
        double previousHiddenState = 0.1;
        double previousCellState = 0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_MinValueInput_ReturnsExpectedHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = double.MinValue;
        double previousHiddenState = 0.1;
        double previousCellState = 0.0;

        double newHiddenState = lstmCell.Process(input, previousHiddenState, previousCellState);

        AssertValidHiddenState(newHiddenState);
    }

    [Test]
    public void Process_DifferentPreviousStates_ReturnsDifferentHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double input = 0.5;
        double previousHiddenState1 = 0.1;
        double previousCellState1 = 0.0;
        double previousHiddenState2 = 0.2;
        double previousCellState2 = 0.1;

        double newHiddenState1 = lstmCell.Process(input, previousHiddenState1, previousCellState1);
        double newHiddenState2 = lstmCell.Process(input, previousHiddenState2, previousCellState2);

        Assert.That(newHiddenState1, Is.Not.EqualTo(newHiddenState2));
    }

    // Guards against the input value being dropped when the [h, x] vector is assembled.
    [Test]
    public void Process_DifferentInputs_ReturnsDifferentHiddenState()
    {
        LSTMCell lstmCell = new LSTMCell();
        double previousHiddenState = 0.1;
        double previousCellState = 0.0;

        double newHiddenState1 = lstmCell.Process(0.5, previousHiddenState, previousCellState);
        double newHiddenState2 = lstmCell.Process(-0.5, previousHiddenState, previousCellState);

        Assert.That(newHiddenState1, Is.Not.EqualTo(newHiddenState2));
    }
}
// NOTE(review): Main had no visible enclosing type in this file, so it is wrapped in a static
// Program class here. AutoRun presumably comes from NUnitLite, which must be imported elsewhere.
/// <summary>Console entry point for the test executable.</summary>
public static class Program
{
    /// <summary>Hands the command-line arguments to an <c>AutoRun</c> test runner.</summary>
    /// <param name="args">Command-line arguments, forwarded unchanged to the runner.</param>
    /// <returns>The value returned by <c>AutoRun.Execute</c>, which becomes the process exit code.</returns>
    public static int Main(string[] args)
    {
        return new AutoRun().Execute(args);
    }
}