using System.Collections.Generic;
using System.Globalization;
using Microsoft.ML.Trainers;
/// <summary>User id under which predictions are requested for the current user.</summary>
private const int CurrentUserId = 1000001;

/// <summary>Number of movies listed in each of the top and bottom recommendation tables.</summary>
private const int NumRecommendations = 20;

/// <summary>Fraction of all ratings held out as the test set by the train/test split.</summary>
private const double TraingTestDataSplitFraction = 0.2;

/// <summary>Fixed seed for the train/test split so the partition is repeatable between runs.</summary>
private const int TraingTestDataSplitRandomSeed = 42;

/// <summary>Movie catalogue iterated by MakePredictions.</summary>
private static List<Movie> _movies;

/// <summary>Movie catalogue indexed by exact title; used to resolve the titles in CurrentUserMovieRatings.</summary>
private static Dictionary<string, Movie> _moviesTitleToMovie;
/// <summary>
/// The current user's ratings for a fixed list of movie titles, keyed by exact catalogue title.
/// A <c>null</c> value means the user has not rated that title.
/// </summary>
/// <remarks>
/// Keys are looked up verbatim in the movie catalogue, so they must match its Title column exactly.
/// ValidateCurrentUserMovieRatings requires at least 10 non-null ratings, each between 0.5 and 5
/// and a multiple of 0.5.
/// </remarks>
private static Dictionary<string, double?> CurrentUserMovieRatings
    = new Dictionary<string, double?>()
    {
        {"Arrival (2016)", null},
        {"Doctor Strange (2016)", null},
        {"Guardians of the Galaxy 2 (2017)", 3.5},
        {"Rogue One: A Star Wars Story (2016)", null},
        {"Captain America: Civil War (2016)", null},
        {"Thor: Ragnarok (2017)", null},
        {"Avengers: Infinity War - Part I (2018)", null},
        {"Wonder Woman (2017)", null},
        {"Blade Runner 2049 (2017)", null},
        {"Baby Driver (2017)", 2},
        {"Spider-Man: Homecoming (2017)", null},
        {"Black Panther (2017)", 3.5},
        {"Dunkirk (2017)", null},
        {"Fantastic Beasts and Where to Find Them (2016)", null},
        {"Deadpool 2 (2018)", null},
        {"Star Wars: The Last Jedi (2017)", null},
        {"Three Billboards Outside Ebbing, Missouri (2017)", null},
        {"La La Land (2016)", 4},
        {"Suicide Squad (2016)", 1},
        {"The Nice Guys (2016)", 4.5},
        {"10 Cloverfield Lane (2016)", 4},
        {"John Wick: Chapter Two (2017)", 3},
        {"Passengers (2016)", null},
        {"X-Men: Apocalypse (2016)", null},
        {"Star Trek Beyond (2016)", null},
        {"Batman v Superman: Dawn of Justice (2016)", 2},
        {"A Quiet Place (2018)", null},
        {"The Shape of Water (2017)", 2},
        {"Avengers: Infinity War - Part II (2019)", null},
        {"Hacksaw Ridge (2016)", null},
        {"Annihilation (2018)", null},
        {"The Accountant (2016)", null},
        {"Hidden Figures (2016)", null},
        {"Now You See Me 2 (2016)", null},
        {"Ant-Man and the Wasp (2018)", 3},
        {"Hell or High Water (2016)", null},
        {"Lady Bird (2017)", null},
        {"Kingsman: The Golden Circle (2017)", 3.5},
        {"Jumanji: Welcome to the Jungle (2017)", null},
        {"The Jungle Book (2016)", null},
        {"Bohemian Rhapsody (2018)", null},
        {"Ghost in the Shell (2017)", null},
        {"Manchester by the Sea (2016)", null},
        {"Solo: A Star Wars Story (2018)", null},
        {"Mission: Impossible - Fallout (2018)", 3.5},
    };
/// <summary>
/// Entry point: validates the hard-coded user ratings, loads the movie catalogue and rating data,
/// trains and evaluates the recommender, then prints predictions for the current user.
/// </summary>
public static void Main()
{
    MLContext mlContext = new MLContext();

    // The validator prints its own reason on failure; there is nothing useful to do afterwards.
    if (!ValidateCurrentUserMovieRatings())
    {
        return;
    }

    Console.WriteLine("\nLoad movies");
    Console.WriteLine("=========================================\n");
    // Must run before LoadData1: GetCurrentUserMovieRatings resolves titles through the
    // title index that LoadMovies builds, and MakePredictions iterates the loaded catalogue.
    LoadMovies();

    Console.WriteLine("\nLoad training and test data");
    Console.WriteLine("=========================================\n");
    (IDataView trainDataView, IDataView testDataView) = LoadData1(mlContext);

    Console.WriteLine("\nTrain the model using training data set");
    Console.WriteLine("=========================================\n");
    ITransformer model = BuildAndTrainModel(mlContext, trainDataView);

    Console.WriteLine("\nEvaluate the model using test data set");
    Console.WriteLine("=========================================\n");
    EvaluateModel(mlContext, testDataView, model);

    Console.WriteLine("\nMake predictions");
    Console.WriteLine("=========================================\n");
    MakePredictions(mlContext, model);
}
/// <summary>
/// Downloads the public ratings file, appends the current user's ratings, and splits the
/// combined rows into training and test views using the fixed fraction and seed.
/// </summary>
/// <param name="mlContext">ML.NET context used to build the data views.</param>
/// <returns>The training view and the held-out test view.</returns>
public static (IDataView trainDataView, IDataView testDataView) LoadData1(MLContext mlContext)
{
    const string ratingsUrl = @"https://pandatechprod.blob.core.windows.net/dnfnewsletter/WarGames/Ratings.csv";

    List<MovieRating> allRatings = LoadListFromUrlCsv<MovieRating>(ratingsUrl);
    // Reported before the user's own rows are appended, i.e. the size of the downloaded data only.
    Console.WriteLine($"RatingsList size : {allRatings.Count}\n");

    allRatings.AddRange(GetCurrentUserMovieRatings());

    IDataView allRatingsView = mlContext.Data.LoadFromEnumerable<MovieRating>(allRatings);
    var split = mlContext.Data.TrainTestSplit(
        allRatingsView,
        testFraction: TraingTestDataSplitFraction,
        seed: TraingTestDataSplitRandomSeed);

    return (split.TrainSet, split.TestSet);
}
/// <summary>
/// Downloads the movie catalogue and populates both the catalogue list (<c>_movies</c>)
/// and the exact-title index (<c>_moviesTitleToMovie</c>).
/// </summary>
/// <remarks>
/// If several catalogue rows share a title, the first one wins. <c>ToDictionary</c> would
/// instead throw on the duplicate and abort the whole run. Rows with a null title cannot be
/// indexed and are left out of the title index, but they remain in <c>_movies</c>.
/// </remarks>
public static void LoadMovies()
{
    var moviesCsvFileUrl = @"https://pandatechprod.blob.core.windows.net/dnfnewsletter/WarGames/Movies.csv";
    var moviesList = LoadListFromUrlCsv<Movie>(moviesCsvFileUrl);
    Console.WriteLine($"MoviesList size : {moviesList.Count}\n");

    // MakePredictions iterates this list; it was previously never assigned.
    _movies = moviesList;

    var titleIndex = new Dictionary<string, Movie>();
    foreach (var movie in moviesList)
    {
        if (movie.Title != null && !titleIndex.ContainsKey(movie.Title))
        {
            titleIndex.Add(movie.Title, movie);
        }
    }

    _moviesTitleToMovie = titleIndex;
}
/// <summary>
/// Converts the hard-coded <c>CurrentUserMovieRatings</c> into <see cref="MovieRating"/> rows
/// for <c>CurrentUserId</c>.
/// </summary>
/// <remarks>
/// Titles not found in the loaded catalogue are reported and skipped. Titles the user has not
/// rated (<c>null</c>) contribute no row.
/// </remarks>
/// <returns>One row per rated title that matches a catalogue movie.</returns>
/// <exception cref="InvalidOperationException">The movie catalogue has not been loaded yet.</exception>
public static List<MovieRating> GetCurrentUserMovieRatings()
{
    if (_moviesTitleToMovie == null)
    {
        throw new InvalidOperationException("Movies have not been loaded; call LoadMovies() first.");
    }

    var ratings = new List<MovieRating>();
    foreach (var pair in CurrentUserMovieRatings)
    {
        // Check the title first so typos are reported even for unrated entries.
        if (!_moviesTitleToMovie.TryGetValue(pair.Key, out Movie movie))
        {
            Console.WriteLine($"Title {pair.Key} does not match movie. Skipping.");
            continue;
        }

        // null means "not rated": dereferencing .Value here would throw.
        if (!pair.Value.HasValue)
        {
            continue;
        }

        ratings.Add(new MovieRating
        {
            UserId = CurrentUserId,
            MovieId = movie.Id,
            Rating = Convert.ToSingle(pair.Value.Value)
        });
    }

    return ratings;
}
/// <summary>
/// Builds the id-to-key encoding plus matrix-factorization pipeline and fits it to the training data.
/// </summary>
/// <param name="mlContext">ML.NET context supplying transforms and trainers.</param>
/// <param name="trainingDataView">Rows with UserId, MovieId and Rating columns.</param>
/// <returns>The fitted model.</returns>
public static ITransformer BuildAndTrainModel(MLContext mlContext, IDataView trainingDataView)
{
    // The matrix-factorization trainer addresses matrix rows/columns through key-typed columns,
    // so the raw user and movie ids are mapped to keys first.
    IEstimator<ITransformer> estimator = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "UserIdEncoded", inputColumnName: "UserId")
        .Append(mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "MovieIdEncoded", inputColumnName: "MovieId"));

    var options = new MatrixFactorizationTrainer.Options
    {
        MatrixColumnIndexColumnName = "UserIdEncoded",
        MatrixRowIndexColumnName = "MovieIdEncoded",
        LabelColumnName = "Rating",
        // Set explicitly so training matches the "30 iterations" announced below,
        // instead of relying on the trainer's default iteration count.
        NumberOfIterations = 30
    };

    var trainerEstimator = estimator.Append(mlContext.Recommendation().Trainers.MatrixFactorization(options));
    ITransformer model = trainerEstimator.Fit(trainingDataView);
    Console.WriteLine("\nIn this output, there are 30 iterations. In each iteration, the measure of error decreases and converges closer and closer to 0.");

    return model;
}
/// <summary>
/// Scores the held-out test rows with <paramref name="model"/> and prints RMSE and R-squared
/// with short explanations.
/// </summary>
/// <param name="mlContext">ML.NET context supplying the regression evaluator.</param>
/// <param name="testDataView">Held-out rows that include the true Rating label.</param>
/// <param name="model">Fitted pipeline that produces a Score column.</param>
public static void EvaluateModel(MLContext mlContext, IDataView testDataView, ITransformer model)
{
    // Compare the model's "Score" against the true "Rating" on every held-out row.
    IDataView scoredTestData = model.Transform(testDataView);
    var metrics = mlContext.Regression.Evaluate(scoredTestData, labelColumnName: "Rating", scoreColumnName: "Score");

    Console.WriteLine($"Root Mean Squared Error: {metrics.RootMeanSquaredError}");
    Console.WriteLine("- used to measure the differences between the model predicted values and the test dataset observed values. The lower it is, the better the model is.");

    // NOTE(review): R-squared can also be negative (worse than predicting the mean); the text below simplifies.
    Console.WriteLine($"\nRSquared: {metrics.RSquared}");
    Console.WriteLine("- indicates how well data fits a model. A value of 0 means data is random, value of 1 means that model matches data exactly.");
}
/// <summary>
/// Predicts the current user's rating for each catalogue movie whose title appears in
/// <c>CurrentUserMovieRatings</c>, then prints the highest- and lowest-scoring movies.
/// </summary>
/// <param name="mlContext">ML.NET context used to create the prediction engine.</param>
/// <param name="model">Fitted recommendation model.</param>
/// <exception cref="InvalidOperationException">The movie catalogue has not been loaded yet.</exception>
public static void MakePredictions(MLContext mlContext, ITransformer model)
{
    if (_movies == null)
    {
        throw new InvalidOperationException("Movies have not been loaded; call LoadMovies() first.");
    }

    var moviePredictions = new List<MoviePrediction>();

    // Creating a prediction engine is costly, so build one, reuse it for every movie, and dispose it afterwards.
    using (var predictionEngine = mlContext.Model.CreatePredictionEngine<MovieRating, MovieRatingPrediction>(model))
    {
        foreach (Movie movie in _movies)
        {
            // NOTE(review): this predicts only for titles listed in CurrentUserMovieRatings, including
            // ones the user already rated. Confirm whether rated (or all listed) titles should be
            // excluded instead.
            if (movie.Title == null || !CurrentUserMovieRatings.ContainsKey(movie.Title))
            {
                continue;
            }

            var input = new MovieRating { UserId = CurrentUserId, MovieId = movie.Id };
            float predictedRating = predictionEngine.Predict(input).Score;
            moviePredictions.Add(new MoviePrediction { Movie = movie, Rating = predictedRating });
        }
    }

    var topMovieRecommendations = moviePredictions.OrderByDescending(mp => mp.Rating).Take(NumRecommendations);
    Console.WriteLine($"Top Recommendations:\n");
    WriteRecommendations(topMovieRecommendations);

    var bottomMovieRecommendations = moviePredictions.OrderBy(mp => mp.Rating).Take(NumRecommendations);
    Console.WriteLine($"\nBottom Un-Recommendations:\n");
    WriteRecommendations(bottomMovieRecommendations);
}
/// <summary>
/// Prints a fixed-width table: one header row, then one row per prediction showing the title,
/// the predicted rating, and the movie's ratings count and average.
/// </summary>
/// <param name="movieRecommendations">Predictions to print, in the order given.</param>
private static void WriteRecommendations(IEnumerable<MoviePrediction> movieRecommendations)
{
    // Column widths here must stay in step with the data rows below.
    const string header = "{0,-40} {1,-20} {2,-20} {3,-20}";
    Console.WriteLine(string.Format(header, "Title", "Predicted rating", "Total Ratings", "Avg rating"));

    foreach (MoviePrediction entry in movieRecommendations)
    {
        var movie = entry.Movie;
        Console.WriteLine($"{movie.Title,-40} {entry.Rating,-20:0.00} {movie.RatingsCount,-20} {movie.RatingsAvg,-20:0.00}");
    }
}
/// <summary>
/// Predicts a single rating for <paramref name="userId"/> and <paramref name="movieId"/>.
/// </summary>
/// <remarks>
/// This method creates and disposes a prediction engine on every call. For many predictions,
/// create one engine and reuse it instead.
/// </remarks>
/// <param name="mlContext">ML.NET context used to create the prediction engine.</param>
/// <param name="model">Fitted recommendation model.</param>
/// <param name="userId">User to predict for.</param>
/// <param name="movieId">Movie to predict the rating of.</param>
/// <returns>The model's predicted rating (its Score output).</returns>
public static float MakePrediction(MLContext mlContext, ITransformer model, int userId, int movieId)
{
    // The engine is IDisposable; previously it was abandoned after each call.
    using (var predictionEngine = mlContext.Model.CreatePredictionEngine<MovieRating, MovieRatingPrediction>(model))
    {
        var testInput = new MovieRating { UserId = userId, MovieId = movieId };
        return predictionEngine.Predict(testInput).Score;
    }
}
/// <summary>
/// Downloads a header-less CSV file over HTTP and materializes every row as a <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Record type each CSV row is read into.</typeparam>
/// <param name="cvsFileUrl">Absolute URL of the CSV file.</param>
/// <returns>All records, fully read before the connection is closed.</returns>
public static List<T> LoadListFromUrlCsv<T>(string cvsFileUrl)
{
    var request = (HttpWebRequest)WebRequest.Create(cvsFileUrl);
    // Forces HTTP/1.0 for this request.
    request.ProtocolVersion = HttpVersion.Version10;
    var response = (HttpWebResponse)request.GetResponse();

    // Disposing the reader closes the response stream.
    using (var reader = new StreamReader(response.GetResponseStream()))
    using (var csv = new CsvReader(reader, CultureInfo.InvariantCulture))
    {
        // No header row in the remote files: columns are presumably bound to T's members by position.
        csv.Configuration.HasHeaderRecord = false;

        // Materialize while the reader is still open; GetRecords is lazy.
        return new List<T>(csv.GetRecords<T>());
    }
}
/// <summary>
/// Checks the hard-coded user ratings and prints the first problem found, if any.
/// </summary>
/// <remarks>
/// The rules are: at least 10 non-null ratings, each in the range 0.5 to 5 inclusive, and each a
/// multiple of 0.5.
/// </remarks>
/// <returns><c>true</c> when all rules pass; otherwise <c>false</c>.</returns>
public static bool ValidateCurrentUserMovieRatings()
{
    if (CurrentUserMovieRatings.Values.Count(v => v != null) < 10)
    {
        Console.WriteLine("Please rate at least 10 movies");
        return false;
    }

    if (CurrentUserMovieRatings.Values.Any(v => v != null && (v.Value < 0.5 || v.Value > 5)))
    {
        // The bound checked is 0.5; the old message wrongly said 1.
        Console.WriteLine("Please make sure your ratings are between 0.5 and 5");
        return false;
    }

    // Exact comparison is safe: multiples of 0.5 are exactly representable as doubles.
    if (CurrentUserMovieRatings.Values.Any(v => v != null && (v.Value % .5 != 0)))
    {
        Console.WriteLine("Please make sure your rating is divisible by .5");
        return false;
    }

    return true;
}
// Movie catalogue row, read from Movies.csv by LoadMovies via LoadListFromUrlCsv<Movie>.
// NOTE(review): that CSV is read with HasHeaderRecord = false, so columns are presumably bound
// by position. Keep this declaration order in sync with the file's column order.
/// <summary>Catalogue id; passed as the MovieId when predicting a rating for this movie.</summary>
public int Id { get; set; }
/// <summary>Display title; also the exact-match key against the titles in CurrentUserMovieRatings.</summary>
public string Title { get; set; }
/// <summary>Genre text from the catalogue. Not read by the visible code; its format is unconfirmed.</summary>
public string Genres { get; set; }
/// <summary>Printed in the "Total Ratings" column; presumably the number of ratings this movie received — confirm.</summary>
public int RatingsCount { get; set; }
/// <summary>Printed in the "Avg rating" column with two decimals; presumably the mean rating — confirm.</summary>
public float RatingsAvg { get; set; }
// One user/movie/rating row: read from Ratings.csv and also built for the current user.
// The ML pipeline consumes these properties by column name ("UserId", "MovieId", "Rating").
// NOTE(review): Ratings.csv is read without a header row, so declaration order presumably has to
// match the file's column order.
/// <summary>Rating user's id; mapped to the "UserIdEncoded" key column before training.</summary>
public int UserId { get; set; }
/// <summary>Rated movie's id; mapped to the "MovieIdEncoded" key column before training.</summary>
public int MovieId { get; set; }
/// <summary>The rating value; used as the training label and the evaluation label.</summary>
public float Rating { get; set; }
/// <summary>
/// Output type of the prediction engine; its Score member carries the predicted rating
/// (also the "Score" column used during evaluation).
/// </summary>
public class MovieRatingPrediction
/// <summary>
/// Pairs a catalogue movie with the rating predicted for the current user, for ranking and display.
/// </summary>
public class MoviePrediction