using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
// Demo entry point: filters GetSamples() through Extend()/AtLeast(), prints the expression
// tree before and after Unextend(), then runs both queries and reports whether they
// return the same rows (compared by Index).
// NOTE(review): this copy of the method appears to be missing lines — at least the
// `var query =` declaration that the `from` clause belongs to, the `select new { ... }`
// projection (which must expose an Index member, since row.Index is read below), and
// all braces. Confirm against the original source before building.
public static void Main()
IQueryable<Sample> source = GetSamples();
// Lower bound handed to AtLeast: only samples timestamped at or after this instant pass.
DateTime? limit = DateTime.Now;
// Extend() wraps the source so the finished query can later be rewritten by Unextend().
from sample in source.Extend()
where sample.Timestamp.AtLeast(limit)
Date = sample.Timestamp.Date,
Time = sample.Timestamp.TimeOfDay,
// When query is an ExtendedQueryable, Unextend() returns a query on the underlying
// provider whose expression has [Expression]-marked calls (here AtLeast) inlined;
// any other query is returned unchanged.
var modified = query.Unextend();
Console.WriteLine("Unmodified:");
Console.WriteLine(query.Expression.ToString());
Console.WriteLine("Modified:");
Console.WriteLine(modified.Expression.ToString());
// NOTE(review): if query is an ExtendedQueryable (as the Extend() call suggests),
// enumerating it also applies the transform, so this compares two transformed
// executions rather than raw vs. transformed — confirm this is the intended check.
var directResults = query.ToArray();
var modifiedResults = modified.ToArray();
// Compare row counts first, then Index values row by row.
bool different = directResults.Length != modifiedResults.Length;
// Skipped entirely when the counts differ; otherwise prints rows up to and
// including the first Index mismatch.
for (int i = 0; !different && i < directResults.Length; i++)
var row = directResults[i];
Console.WriteLine($"{i}\t{row.Index}\t{row.Date}\t{row.Time}");
different |= row.Index != modifiedResults[i].Index;
Console.WriteLine($"Results {(different ? "do not " : "")}match.");
// In-memory backing data for GetSamples(): sample i gets Index i and a timestamp
// i hours after midnight two days ago (DateTime.Today.AddDays(-2).AddHours(i)).
// NOTE(review): the sequence of indices feeding .Select (presumably an
// Enumerable.Range(...) call) and the conversion to List<Sample> (presumably .ToList())
// appear to be missing from this copy — confirm against the original.
static readonly List<Sample> _samples =
.Select(i => new Sample(i, DateTime.Today.AddDays(-2).AddHours(i)))
/// <summary>
/// Exposes the in-memory <c>_samples</c> list as an <see cref="IQueryable{T}"/> so it can be
/// used as the source of a LINQ query.
/// </summary>
public static IQueryable<Sample> GetSamples()
{
    return Queryable.AsQueryable(_samples);
}
/// <summary>A single demo data point.</summary>
/// <param name="Index">Sequence number of the sample.</param>
/// <param name="Timestamp">When the sample was taken.</param>
public record class Sample(int Index, DateTime Timestamp);
/// <summary>
/// Query extensions that opt a query into <see cref="ExpressionAttribute"/> expansion
/// (<see cref="Extend{T}"/>), take the expanded query back out (<see cref="Unextend{T}"/>),
/// and provide <see cref="AtLeast"/> as an example of an expandable method.
/// </summary>
public static class Extensions
{
    /// <summary>
    /// Wraps <paramref name="query"/> in an <see cref="ExtendedQueryable{T}"/>. A query that is
    /// already wrapped is returned as-is, so wrappers are never nested.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <see langword="null"/>.</exception>
    public static IQueryable<T> Extend<T>(this IQueryable<T> query)
    {
        // Fail here rather than later, on first use of a wrapper built around null.
        ArgumentNullException.ThrowIfNull(query);
        return query is ExtendedQueryable<T> ? query : new ExtendedQueryable<T>(query);
    }

    /// <summary>
    /// Returns the transformed query for an <see cref="ExtendedQueryable{T}"/>. Any other query
    /// (including <see langword="null"/>) is returned unchanged.
    /// </summary>
    public static IQueryable<T> Unextend<T>(this IQueryable<T> query)
        => query is ExtendedQueryable<T> extended ? extended.GetTransformedQuery() : query;

    /// <summary>
    /// Returns <see langword="true"/> when <paramref name="minValue"/> is <see langword="null"/>
    /// or <paramref name="value"/> is greater than or equal to it.
    /// </summary>
    /// <remarks>
    /// Marked with <see cref="ExpressionAttribute"/> so query rewriting can replace the call
    /// with the <c>_atLeast</c> expression; the two implementations must stay logically identical.
    /// </remarks>
    [Expression(nameof(_atLeast))]
    public static bool AtLeast(this DateTime value, DateTime? minValue)
        => minValue is null || value >= minValue;

    // Expression-tree twin of AtLeast, located by name through [Expression(nameof(_atLeast))]
    // and read via reflection. readonly: nothing should reassign it after type initialization.
    private static readonly Expression<Func<DateTime, DateTime?, bool>> _atLeast =
        (value, minValue) => minValue == null || value >= minValue;
}
/// <summary>
/// Marks a method whose calls can be replaced, inside query expression trees, by the
/// <see cref="System.Linq.Expressions.LambdaExpression"/> held in a static member of the
/// same declaring type.
/// </summary>
/// <param name="memberName">Name of that static member (typically supplied via <c>nameof</c>).</param>
[AttributeUsage(AttributeTargets.Method)]
public sealed class ExpressionAttribute(string memberName) : Attribute
{
    /// <summary>Name of the static member that supplies the replacement expression.</summary>
    public string MemberName { get; init; } = memberName;
}
/// <summary>
/// An <see cref="IQueryable{T}"/> wrapper that acts as its own query provider, so every LINQ
/// operator applied to it yields another wrapper. When the query is enumerated or executed,
/// its expression tree is rewritten by <c>TransformVisitor</c> and handed to the wrapped
/// provider.
/// </summary>
/// <typeparam name="T">Element type of the query.</typeparam>
public class ExtendedQueryable<T> : IQueryable<T>, IQueryProvider
{
    private readonly IQueryable<T> _wrapped;

    #region IQueryable<T> properties
    public Type ElementType => _wrapped.ElementType;
    public Expression Expression => _wrapped.Expression;
    // Returning this (not _wrapped.Provider) routes LINQ operators through CreateQuery/Execute below.
    public IQueryProvider Provider => this;
    #endregion

    /// <summary>Wraps <paramref name="wrapped"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="wrapped"/> is <see langword="null"/>.</exception>
    public ExtendedQueryable(IQueryable<T> wrapped)
    {
        ArgumentNullException.ThrowIfNull(wrapped);
        _wrapped = wrapped;
    }

    /// <summary>
    /// Returns a query on the wrapped provider whose expression has been rewritten by
    /// <c>TransformVisitor</c>. The result is not wrapped.
    /// </summary>
    public IQueryable<T> GetTransformedQuery()
    {
        Expression transformed = TransformVisitor.Transform(_wrapped.Expression);
        return _wrapped.Provider.CreateQuery<T>(transformed);
    }

    #region IQueryable<T> methods
    /// <summary>Enumerates the transformed query.</summary>
    public IEnumerator<T> GetEnumerator() => GetTransformedQuery().GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    #endregion

    #region IQueryProvider methods
    /// <summary>
    /// Non-generic query creation. The element type is taken from the query the wrapped provider
    /// builds, not assumed to be <typeparamref name="T"/> (a projection changes it).
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        ArgumentNullException.ThrowIfNull(expression);
        IQueryable inner = _wrapped.Provider.CreateQuery(expression);
        Type wrapperType = typeof(ExtendedQueryable<>).MakeGenericType(inner.ElementType);
        return (IQueryable)Activator.CreateInstance(wrapperType, inner)!;
    }

    /// <summary>
    /// Builds the next query with the wrapped provider and wraps it, so later operators and
    /// enumeration still go through this provider.
    /// </summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        var inner = _wrapped.Provider.CreateQuery<TElement>(expression);
        return new ExtendedQueryable<TElement>(inner);
    }

    /// <summary>
    /// Transforms <paramref name="expression"/> (for example a Count/First call built on this
    /// query) and executes it with the wrapped provider, returning an untyped result.
    /// </summary>
    public object? Execute(Expression expression)
    {
        ArgumentNullException.ThrowIfNull(expression);
        Expression transformed = TransformVisitor.Transform(expression);
        return _wrapped.Provider.Execute(transformed);
    }

    /// <summary>
    /// Transforms <paramref name="expression"/> and executes it with the wrapped provider.
    /// The supplied expression, not this query's own expression, is what must be executed.
    /// </summary>
    public TResult Execute<TResult>(Expression expression)
    {
        ArgumentNullException.ThrowIfNull(expression);
        Expression transformed = TransformVisitor.Transform(expression);
        return _wrapped.Provider.Execute<TResult>(transformed);
    }
    #endregion
}
// Expression visitor that inlines calls to methods marked with [Expression(name)]:
// the named static member of the method's declaring type supplies a LambdaExpression
// whose body replaces the call, with the lambda's parameters bound positionally to the
// call's arguments (receiver first for instance methods).
// NOTE(review): several lines of this class appear to be missing from this copy
// (braces, the end of the GetMember chain, the conditions guarding the early returns,
// and the end of the switch expression) — confirm against the original.
internal sealed class TransformVisitor : ExpressionVisitor
// Private: instances are only created through Transform.
private TransformVisitor()
// Visits expression (and its subtree) with a fresh visitor and casts the result back to T.
public static T Transform<T>(T expression)
TransformVisitor visitor = new();
return (T)visitor.Visit(expression);
protected override Expression VisitMethodCall(MethodCallExpression node)
// Calls to methods without [Expression] are visited normally.
if (node.Method.GetCustomAttribute<ExpressionAttribute>() is not ExpressionAttribute attr)
return base.VisitMethodCall(node);
// Look up the member named by the attribute among the public and non-public static
// members of the method's declaring type.
// NOTE(review): DeclaringType is nullable; the chain as shown does not guard against null.
MemberInfo? source = node.Method.DeclaringType
.GetMember(attr.MemberName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static)
// Presumably reached when no matching member is found — the guarding condition is not in this copy.
return base.VisitMethodCall(node);
// The lambda comes from a static field value, a static property value, or the result of
// invoking a parameterless static method. This reflection runs on every matching call.
LambdaExpression? expression = source switch
FieldInfo fi => fi.GetValue(null),
PropertyInfo pi => pi.GetValue(null),
MethodInfo mi => mi.Invoke(null, []),
// Presumably reached when the member did not yield a LambdaExpression — the guard is not in this copy.
return base.VisitMethodCall(node);
// Argument list as the lambda sees it: the receiver (for instance methods) followed by the call arguments.
List<Expression> args = node.Arguments.ToList();
if (node.Object is not null)
args.Insert(0, node.Object);
// Bind each lambda parameter to the argument at the same position.
// NOTE(review): args[i] throws ArgumentOutOfRangeException if the lambda declares more
// parameters than there are arguments; parameter/argument types are not checked either.
var replacements = expression.Parameters
.Select((p, i) => (p, a: args[i]))
.ToDictionary(row => (Expression)row.p, row => row.a);
// NOTE(review): neither the original arguments nor the inlined body are visited again, so
// [Expression]-marked calls nested inside them are not expanded — verify this is intended.
return ReplaceVisitor.Replace(expression.Body, replacements);
internal sealed class ReplaceVisitor : ExpressionVisitor
private readonly IReadOnlyDictionary<Expression, Expression> _replacements;
private ReplaceVisitor(IReadOnlyDictionary<Expression, Expression> replacements)
=> _replacements = replacements;
public static T Replace<T>(T expression, IReadOnlyDictionary<Expression, Expression> replacements)
ReplaceVisitor visitor = new(replacements);
return (T)visitor.Visit(expression);
[return: NotNullIfNotNull(nameof(node))]
public override Expression? Visit(Expression? node)
if (node is not null && _replacements.TryGetValue(node, out var replacement))