using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query;
// File-scoped namespace (C# 10+): every type declared below belongs to it.
namespace ExpressionVisitorTests;

/// <summary>
/// Console entry point that exercises <c>ColourRepository.GetWhere</c>.
/// </summary>
/// <remarks>
/// NOTE(review): the enclosing class declaration was missing from the file; it is named
/// <c>Program</c> here by convention — rename if the project expects a different name.
/// </remarks>
public class Program {
    /// <summary>
    /// Queries the repository for colours named "Red" and prints how many were found.
    /// </summary>
    public static void Main() {
        var colourRepository = new ColourRepository();
        var result = colourRepository.GetWhere(c => c.Name == "Red");
        Console.WriteLine(result.Count());
    }
}
/// <summary>
/// Storage-side representation of a colour (the element type of the repository's backing list).
/// </summary>
public class DtoColour {
    /// <summary>Creates a colour DTO with the given name.</summary>
    /// <param name="name">The colour name.</param>
    public DtoColour(string name) {
        Name = name;
    }

    /// <summary>The colour name.</summary>
    public string Name { get; set; }
}
/// <summary>
/// Domain-model representation of a colour; this is the type callers filter on and receive
/// back from the repository.
/// </summary>
public class DomainColour {
    /// <summary>Creates a domain colour with the given name.</summary>
    /// <param name="name">The colour name.</param>
    public DomainColour(string name) {
        Name = name;
    }

    /// <summary>The colour name.</summary>
    public string Name { get; set; }
}
/// <summary>
/// In-memory repository that stores <c>DtoColour</c> records but is queried with, and returns,
/// <c>DomainColour</c> objects.
/// </summary>
public class ColourRepository {
    private IList<DtoColour> Colours { get; set; }

    /// <summary>Creates the repository with a small built-in set of colours.</summary>
    public ColourRepository() {
        // NOTE(review): the original initializer entries were not recoverable from this file;
        // these sample values are placeholders — replace them with the intended seed data.
        Colours = new List<DtoColour>() {
            new DtoColour("Red"),
            new DtoColour("Green"),
            new DtoColour("Blue"),
        };
    }

    /// <summary>
    /// Returns every stored colour that satisfies <paramref name="predicate"/>, mapped to the
    /// domain model.
    /// </summary>
    /// <param name="predicate">
    /// A filter written against <c>DomainColour</c>; it is rewritten to run against
    /// <c>DtoColour</c> via <c>ExpressionExtensions.Convert</c>.
    /// </param>
    /// <returns>A materialized list of the matching colours.</returns>
    public IEnumerable<DomainColour> GetWhere(Expression<Func<DomainColour, bool>> predicate) {
        Expression<Func<DtoColour, bool>> convertedPredicate =
            predicate.Convert<DomainColour, DtoColour>();

        // Enumerable.Where takes a delegate, not an expression tree, so compile the rewritten tree.
        Func<DtoColour, bool> matches = convertedPredicate.Compile();

        return Colours
            .Where(matches)
            .Select(c => new DomainColour(c.Name))
            .ToList();
    }
}
/// <summary>
/// <see cref="ModelBuilder"/> helpers that apply one query filter, written against an interface,
/// to every mapped root entity type implementing that interface.
/// </summary>
public static class ModelBuilderExtensions
{
    // Open generic SetQueryFilter<TEntity, TEntityInterface>. It is closed per entity type at
    // runtime because TEntity is only known as a System.Type while iterating the model.
    static readonly MethodInfo SetQueryFilterMethod = typeof(ModelBuilderExtensions)
        .GetMethods(BindingFlags.NonPublic | BindingFlags.Static)
        .Single(t => t.IsGenericMethod && t.Name == nameof(SetQueryFilter));

    /// <summary>
    /// Adds <paramref name="filterExpression"/> as a query filter on every entity type that has no
    /// base type and whose CLR type implements <typeparamref name="TEntityInterface"/>. A filter
    /// already configured on such an entity is kept and combined with the new one using AND.
    /// </summary>
    /// <typeparam name="TEntityInterface">
    /// The interface (or base class) the filter is written against. It must be a reference type,
    /// as required by <c>SetQueryFilter</c>'s constraints.
    /// </typeparam>
    /// <param name="builder">The model builder being configured.</param>
    /// <param name="filterExpression">Predicate expressed over <typeparamref name="TEntityInterface"/>.</param>
    public static void SetQueryFilterOnAllEntities<TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        foreach (var type in builder.Model.GetEntityTypes()
            .Where(t => t.BaseType == null)
            .Select(t => t.ClrType)
            .Where(t => typeof(TEntityInterface).IsAssignableFrom(t)))
        {
            builder.SetEntityQueryFilter(
                type,
                filterExpression);
        }
    }

    /// <summary>
    /// Closes <c>SetQueryFilter</c> over the runtime <paramref name="entityType"/> and invokes it.
    /// </summary>
    static void SetEntityQueryFilter<TEntityInterface>(
        this ModelBuilder builder,
        Type entityType,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        SetQueryFilterMethod
            .MakeGenericMethod(entityType, typeof(TEntityInterface))
            .Invoke(null, new object[] { builder, filterExpression });
    }

    /// <summary>
    /// Retypes the interface-level filter to <typeparamref name="TEntity"/> and appends it to that
    /// entity's query filter.
    /// </summary>
    static void SetQueryFilter<TEntity, TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
        where TEntityInterface : class
        where TEntity : class, TEntityInterface
    {
        var concreteExpression = filterExpression
            .Convert<TEntityInterface, TEntity>();
        builder.Entity<TEntity>()
            .AppendQueryFilter(concreteExpression);
    }

    /// <summary>
    /// Sets <paramref name="expression"/> as the entity's query filter. If a filter is already
    /// present, it becomes <c>existing AndAlso expression</c> instead of replacing it.
    /// </summary>
    /// <remarks>
    /// Both bodies are rebound onto one fresh parameter of the entity's CLR type, because the
    /// combined lambda may reference only the parameter it declares.
    /// </remarks>
    static void AppendQueryFilter<T>(this EntityTypeBuilder entityTypeBuilder, Expression<Func<T, bool>> expression)
    {
        var entityParameter = Expression.Parameter(entityTypeBuilder.Metadata.ClrType);

        var expressionFilter = ReplacingExpressionVisitor.Replace(
            expression.Parameters.Single(), entityParameter, expression.Body);

        var currentQueryFilter = entityTypeBuilder.Metadata.GetQueryFilter();
        if (currentQueryFilter != null)
        {
            var currentExpressionFilter = ReplacingExpressionVisitor.Replace(
                currentQueryFilter.Parameters.Single(), entityParameter, currentQueryFilter.Body);
            expressionFilter = Expression.AndAlso(currentExpressionFilter, expressionFilter);
        }

        var lambdaExpression = Expression.Lambda(expressionFilter, entityParameter);
        entityTypeBuilder.HasQueryFilter(lambdaExpression);
    }
}
/// <summary>Expression-tree helpers.</summary>
public static class ExpressionExtensions
{
    /// <summary>
    /// Rewrites a predicate over <typeparamref name="TSource"/> into a predicate over
    /// <typeparamref name="TTarget"/> by retyping its <typeparamref name="TSource"/> parameters
    /// with <c>ParameterTypeVisitor</c>.
    /// </summary>
    /// <param name="root">The predicate to rewrite.</param>
    /// <returns>
    /// The rewritten predicate, or <c>null</c> when <paramref name="root"/> is <c>null</c>
    /// (<see cref="ExpressionVisitor.Visit(Expression)"/> returns <c>null</c> for a <c>null</c> node).
    /// </returns>
    public static Expression<Func<TTarget, bool>> Convert<TSource, TTarget>(
        this Expression<Func<TSource, bool>> root)
    {
        // No diagnostic output here: this runs during EF Core model building as well as in app code.
        var visitor = new ParameterTypeVisitor<TSource, TTarget>();
        return (Expression<Func<TTarget, bool>>)visitor.Visit(root);
    }
}
/// <summary>
/// Expression visitor that replaces every parameter of type <typeparamref name="TSource"/> with a
/// same-named parameter of type <typeparamref name="TTarget"/>.
/// </summary>
/// <remarks>
/// <para>Parameters are matched by reference identity. This keeps nested lambdas and lambdas that
/// reuse parameter names correct.</para>
/// <para>Lambdas whose parameter list changes are rebuilt with an inferred Func/Action delegate type.</para>
/// <para>Member accesses whose member is not defined on the new instance type are re-bound by name
/// (<see cref="Expression.PropertyOrField"/>). That call throws <see cref="ArgumentException"/>
/// when <typeparamref name="TTarget"/> has no such member.</para>
/// <para>Instance method calls declared on <typeparamref name="TSource"/> are not re-bound.</para>
/// </remarks>
class ParameterTypeVisitor<TSource, TTarget> : ExpressionVisitor
{
    // Original TSource parameter -> its TTarget replacement. Every occurrence reuses the same
    // replacement, so the rewritten tree references exactly the parameters its lambdas declare.
    private readonly Dictionary<ParameterExpression, ParameterExpression> _replacements =
        new Dictionary<ParameterExpression, ParameterExpression>();

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (_replacements.TryGetValue(node, out var replacement))
        {
            return replacement;
        }

        if (node.Type != typeof(TSource))
        {
            return node;
        }

        replacement = Expression.Parameter(typeof(TTarget), node.Name);
        _replacements.Add(node, replacement);
        return replacement;
    }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        // Visit parameters before the body so the replacements exist when the body refers to them.
        var parameters = VisitAndConvert(node.Parameters, nameof(VisitLambda));
        var body = Visit(node.Body);

        // Nothing retyped (e.g. an inner lambda over other types): keep the original delegate type.
        if (body.Type == node.Body.Type && parameters.SequenceEqual(node.Parameters))
        {
            return node.Update(body, parameters);
        }

        // T no longer matches the parameter/body types, so let Expression.Lambda infer the delegate type.
        return Expression.Lambda(body, node.Name, node.TailCall, parameters);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        var instance = Visit(node.Expression);

        // A member declared on TSource (e.g. DomainColour.Name) cannot be read from an unrelated
        // TTarget instance, so look up the member with the same name on the new type instead.
        if (instance != null
            && instance != node.Expression
            && node.Member.DeclaringType != null
            && !node.Member.DeclaringType.IsAssignableFrom(instance.Type))
        {
            return Expression.PropertyOrField(instance, node.Member.Name);
        }

        return node.Update(instance);
    }
}