C#: a method for converting a lambda expression into a SQL WHERE condition

Original article: http://www.cnblogs.com/FengCodes/p/LambdaToSqlWhere.html

Modified version of the code:

/// <summary>
    /// Generates the WHERE part of a SQL statement from a lambda expression.
    /// </summary>
    public class SqlGenerate
    {
        /// <summary>
        /// Builds a SQL WHERE condition (without the WHERE keyword) from a predicate expression.
        /// </summary>
        /// <typeparam name="T">Entity type whose members are translated to column names.</typeparam>
        /// <param name="expression">Predicate to translate.</param>
        /// <param name="databaseType">
        /// Database type name, converted with ToEnum&lt;DataBaseType&gt;; decides whether column names
        /// are double-quoted (PostgreSQL, Oracle) and which LIKE concatenation syntax is used.
        /// </param>
        /// <returns>The condition text with every constant value inlined as a SQL literal.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public string GetWhereByLambda<T>(Expression<Func<T, bool>> expression, string databaseType = "SqlServer")
        {
            if (expression == null)
                throw new ArgumentNullException("expression");

            ConditionBuilder conditionBuilder = new ConditionBuilder();
            conditionBuilder.WithQuotationMarks = GetWithQuotationMarks(databaseType); // quote column names (PostgreSQL, Oracle)
            conditionBuilder.DataBaseType = databaseType.ToEnum<DataBaseType>(true);
            conditionBuilder.Build(expression);

            // ConditionBuilder.Arguments returns a new array on every access, so take ONE copy and format
            // that copy. Writing through the property (as before) silently discarded every formatted value.
            object[] arguments = conditionBuilder.Arguments;
            for (int i = 0; i < arguments.Length; i++)
                arguments[i] = FormatArgument(arguments[i]);
            return string.Format(conditionBuilder.Condition, arguments);
        }

        /// <summary>
        /// Renders one captured constant as a SQL literal.
        /// </summary>
        /// <param name="value">Value collected by <see cref="ConditionBuilder"/>.</param>
        /// <returns>SQL text to splice into the condition.</returns>
        private static string FormatArgument(object value)
        {
            if (value == null || value is DBNull)
                return "null";

            if (value is string || value is char)
            {
                string text = value.ToString();
                string probe = text.Trim().ToLowerInvariant();
                // Values that already carry their own SQL syntax are spliced in verbatim.
                // NOTE(review): this is an injection vector - never let untrusted input reach this path.
                bool isRawSql = probe.StartsWith("in(", StringComparison.Ordinal) ||
                    probe.StartsWith("not in(", StringComparison.Ordinal) ||
                    probe.StartsWith("like '", StringComparison.Ordinal) ||
                    probe.StartsWith("not like", StringComparison.Ordinal);
                if (isRawSql)
                    return " " + text + " ";
                // Double embedded quotes so values such as O'Brien stay a single literal.
                return " '" + text.Replace("'", "''") + "' ";
            }

            if (value is bool)
                return value.ToString();

            if (value is Enum)
            {
                // Compare enums by their underlying number, not by their name.
                object number = Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()));
                return Convert.ToString(number, System.Globalization.CultureInfo.InvariantCulture);
            }

            if (IsNumeric(value))
                // Invariant culture keeps "." as the decimal separator regardless of the thread culture.
                return Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture);

            if (value is DateTime)
                // ISO 8601 literal; previously DateTime fell into the ValueType branch and was left unquoted.
                return "'" + ((DateTime)value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture) + "'";

            // Anything else (Guid, DateTimeOffset, TimeSpan, reference types, ...) becomes a quoted literal.
            string other = Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture) ?? string.Empty;
            return "'" + other.Replace("'", "''") + "'";
        }

        /// <summary>
        /// Returns true for the built-in numeric primitive types.
        /// </summary>
        private static bool IsNumeric(object value)
        {
            return value is int || value is long || value is short || value is byte || value is sbyte ||
                value is uint || value is ulong || value is ushort ||
                value is decimal || value is double || value is float;
        }

        /// <summary>
        /// Decides whether column names must be wrapped in double quotes.
        /// </summary>
        /// <param name="databaseType">Database type name, converted with ToEnum&lt;DataBaseType&gt;.</param>
        /// <returns>True for PostgreSQL and Oracle; otherwise false.</returns>
        private bool GetWithQuotationMarks(string databaseType)
        {
            bool result = false;
            switch (databaseType.ToEnum<DataBaseType>(true))
            {
                case DataBaseType.PostGreSql:
                case DataBaseType.Oracle:
                    result = true;
                    break;
            }
            return result;
        }
    }
 /// <summary>
    /// Expression visitor that turns a predicate expression tree into a SQL WHERE fragment.
    /// Constant values are not inlined: each one is appended to <see cref="Arguments"/> and referenced
    /// from <see cref="Condition"/> through a composite-format placeholder ("{0}", "{1}", ...).
    /// The caller formats those values and combines both with string.Format.
    /// </summary>
    internal sealed class ConditionBuilder : ExpressionVisitor
    {
        /// <summary>
        /// Target database type; selects the string-concatenation syntax used in LIKE patterns.
        /// </summary>
        private DataBaseType dataBaseType = DataBaseType.SqlServer;
        /// <summary>
        /// Whether column names are wrapped in double quotes.
        /// </summary>
        private bool withQuotationMarks = false;
        /// <summary>
        /// Constant values collected while visiting; the index is the placeholder number.
        /// </summary>
        private List<object> arguments;
        /// <summary>
        /// Partially built SQL fragments; Visit* handlers push the fragment for the node they translate.
        /// </summary>
        private Stack<string> conditionParts;

        /// <summary>
        /// Wraps a column name in double quotes when <see cref="WithQuotationMarks"/> is enabled.
        /// </summary>
        /// <param name="str">Column name.</param>
        /// <returns>The trimmed, quoted name, or <paramref name="str"/> unchanged.</returns>
        private string AddQuotationMarks(string str)
        {
            // The previous test was inverted (it only "quoted" empty names); quote real names instead.
            if (withQuotationMarks && !string.IsNullOrWhiteSpace(str))
                return "\"" + str.Trim() + "\"";
            return str;
        }

        /// <summary>
        /// Translates <paramref name="expression"/>. Afterwards read the result from
        /// <see cref="Condition"/> and <see cref="Arguments"/>.
        /// </summary>
        /// <param name="expression">Predicate expression to translate.</param>
        public void Build(Expression expression)
        {
            // Collapse closures and other parameter-independent sub-trees into constants first.
            PartialEvaluator evaluator = new PartialEvaluator();
            Expression evaluatedExpression = evaluator.Eval(expression);
            this.arguments = new List<object>();
            this.conditionParts = new Stack<string>();
            this.Visit(evaluatedExpression);
        }

        /// <summary>
        /// Translates comparison, logical and arithmetic operators to "(left op right)".
        /// Equality tests against a null constant become "(x is null)" / "(x is not null)".
        /// </summary>
        /// <exception cref="NotSupportedException">The operator has no SQL mapping here.</exception>
        protected sealed override Expression VisitBinary(BinaryExpression b)
        {
            if (b == null)
                return b;

            // "col = NULL" is never true in SQL, and the null value would render as an empty placeholder.
            if (b.NodeType == ExpressionType.Equal || b.NodeType == ExpressionType.NotEqual)
            {
                Expression operand = null;
                if (IsNullConstant(b.Right))
                    operand = b.Left;
                else if (IsNullConstant(b.Left))
                    operand = b.Right;
                if (operand != null)
                {
                    this.Visit(operand);
                    string column = this.conditionParts.Pop();
                    string test = b.NodeType == ExpressionType.Equal ? "is null" : "is not null";
                    this.conditionParts.Push(String.Format("({0} {1})", column, test));
                    return b;
                }
            }

            string opr;
            switch (b.NodeType)
            {
                case ExpressionType.Equal:
                    opr = "=";
                    break;
                case ExpressionType.NotEqual:
                    opr = "<>";
                    break;
                case ExpressionType.GreaterThan:
                    opr = ">";
                    break;
                case ExpressionType.GreaterThanOrEqual:
                    opr = ">=";
                    break;
                case ExpressionType.LessThan:
                    opr = "<";
                    break;
                case ExpressionType.LessThanOrEqual:
                    opr = "<=";
                    break;
                case ExpressionType.AndAlso:
                    opr = "and";
                    break;
                case ExpressionType.OrElse:
                    opr = "or";
                    break;
                case ExpressionType.Add:
                    opr = "+";
                    break;
                case ExpressionType.Subtract:
                    opr = "-";
                    break;
                case ExpressionType.Multiply:
                    opr = "*";
                    break;
                case ExpressionType.Divide:
                    opr = "/";
                    break;
                default:
                    throw new NotSupportedException(b.NodeType + " is not supported.");
            }

            this.Visit(b.Left);
            this.Visit(b.Right);
            // Pushed left first, so the right operand is on top of the stack.
            string right = this.conditionParts.Pop();
            string left = this.conditionParts.Pop();
            string condition = String.Format("({0} {1} {2})", left, opr, right);
            this.conditionParts.Push(condition);
            return b;
        }

        /// <summary>
        /// Returns true when <paramref name="e"/> is a constant whose value is null.
        /// </summary>
        private static bool IsNullConstant(Expression e)
        {
            ConstantExpression constant = e as ConstantExpression;
            return constant != null && constant.Value == null;
        }

        /// <summary>
        /// Records the constant's value and pushes its placeholder ("{n}") instead of the value itself.
        /// </summary>
        protected sealed override Expression VisitConstant(ConstantExpression c)
        {
            if (c == null)
                return c;
            this.arguments.Add(c.Value);
            // "{{{0}}}" renders as "{n}": escaped braces around the argument index.
            this.conditionParts.Push(String.Format("{{{0}}}", this.arguments.Count - 1));
            return c;
        }

        /// <summary>
        /// Pushes the (optionally quoted) name of a property or field as a column reference.
        /// Other member kinds are left untouched and push nothing.
        /// </summary>
        protected sealed override Expression VisitMemberAccess(MemberExpression m)
        {
            if (m == null)
                return m;
            // Fields are columns too; ignoring them left the stack short and made the caller's Pop fail.
            if (!(m.Member is PropertyInfo) && !(m.Member is FieldInfo))
                return m;
            this.conditionParts.Push(String.Format(" {0} ", AddQuotationMarks(m.Member.Name)));
            return m;
        }

        /// <summary>
        /// Legacy string-based translator for a binary node. It is not called by the visitor path (<see cref="Build"/>).
        /// Comparisons against a null literal are rewritten to "is null" / "is not null".
        /// </summary>
        private string BinarExpressionProvider(Expression left, Expression right, ExpressionType type)
        {
            string stringInfo = "(";
            // Left operand first.
            stringInfo += ExpressionRouter(left);
            stringInfo += ExpressionTypeCast(type);
            // Then the right operand.
            string tmpStr = ExpressionRouter(right);
            if (tmpStr.SameAS("null"))
            {
                if (stringInfo.EndsWith(" ="))
                    stringInfo = stringInfo.Substring(0, stringInfo.Length - 1) + " is null";
                else if (stringInfo.EndsWith("<>"))
                    // Strip both characters of "<>" (stripping only one left a dangling "<").
                    stringInfo = stringInfo.Substring(0, stringInfo.Length - 2) + " is not null";
            }
            else
                stringInfo += tmpStr;
            return stringInfo += ")";
        }

        /// <summary>
        /// Legacy recursive translator used by <see cref="BinarExpressionProvider"/>.
        /// Returns null for unsupported nodes.
        /// </summary>
        private string ExpressionRouter(Expression exp)
        {
            if (exp is BinaryExpression)
            {
                BinaryExpression be = ((BinaryExpression)exp);
                return BinarExpressionProvider(be.Left, be.Right, be.NodeType);
            }
            if (exp is MemberExpression)
            {
                MemberExpression me = ((MemberExpression)exp);
                return me.Member.Name;
            }
            if (exp is NewArrayExpression)
            {
                NewArrayExpression ae = ((NewArrayExpression)exp);
                StringBuilder tmpstr = new StringBuilder();
                foreach (Expression ex in ae.Expressions)
                {
                    tmpstr.Append(ExpressionRouter(ex));
                    tmpstr.Append(",");
                }
                // An empty array has no trailing comma to strip (ToString(0, -1) would throw).
                return tmpstr.Length == 0 ? string.Empty : tmpstr.ToString(0, tmpstr.Length - 1);
            }
            if (exp is MethodCallExpression)
            {
                MethodCallExpression mce = (MethodCallExpression)exp;
                if (mce.Method.Name.SameAS("Like"))
                    return string.Format("({0} like {1})", ExpressionRouter(mce.Arguments[0]), ExpressionRouter(mce.Arguments[1]));
                else if (mce.Method.Name.SameAS("NotLike"))
                    return string.Format("({0} not like {1})", ExpressionRouter(mce.Arguments[0]), ExpressionRouter(mce.Arguments[1]));
                else if (mce.Method.Name.SameAS("In"))
                    return string.Format("{0} in ({1})", ExpressionRouter(mce.Arguments[0]), ExpressionRouter(mce.Arguments[1]));
                else if (mce.Method.Name.SameAS("NotIn"))
                    return string.Format("{0} not in ({1})", ExpressionRouter(mce.Arguments[0]), ExpressionRouter(mce.Arguments[1]));
                else if (mce.Method.Name.SameAS("StartWith"))
                    return string.Format("{0} like '{1}%'", ExpressionRouter(mce.Arguments[0]), ExpressionRouter(mce.Arguments[1]));
                return null;
            }
            if (exp is ConstantExpression)
            {
                ConstantExpression ce = ((ConstantExpression)exp);
                if (ce.Value == null)
                    return "null";
                else if (ce.Value is ValueType)
                    return ce.Value.ToString();
                else if (ce.Value is string || ce.Value is DateTime || ce.Value is char)
                    return string.Format("'{0}'", ce.Value.ToString());
                return null;
            }
            if (exp is UnaryExpression)
            {
                UnaryExpression ue = ((UnaryExpression)exp);
                return ExpressionRouter(ue.Operand);
            }
            return null;
        }

        /// <summary>
        /// Legacy mapping from an expression node type to its SQL operator text (null when unmapped).
        /// </summary>
        private string ExpressionTypeCast(ExpressionType type)
        {
            switch (type)
            {
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                    return " and ";
                case ExpressionType.Equal:
                    return " =";
                case ExpressionType.GreaterThan:
                    return " >";
                case ExpressionType.GreaterThanOrEqual:
                    return ">=";
                case ExpressionType.LessThan:
                    return "<";
                case ExpressionType.LessThanOrEqual:
                    return "<=";
                case ExpressionType.NotEqual:
                    return "<>";
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                    return " or ";
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                    return "+";
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                    return "-";
                case ExpressionType.Divide:
                    return "/";
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                    return "*";
                default:
                    return null;
            }
        }

        /// <summary>
        /// Translates instance method calls of the form <c>column.Method(value)</c>:
        /// string.StartsWith / Contains / EndsWith become LIKE 'x%' / '%x%' / '%x', and Equals becomes "=".
        /// </summary>
        /// <exception cref="NotSupportedException">
        /// The call is static, has no arguments, is not one of the methods above,
        /// or is a LIKE-style method not declared on <see cref="string"/>.
        /// </exception>
        protected sealed override Expression VisitMethodCall(MethodCallExpression m)
        {
            if (m == null)
                return m;
            // Static calls (e.g. Enumerable.Contains(list, x.Id)) have no receiver. Translating them
            // would pop unrelated fragments off the stack, so reject them explicitly.
            if (m.Object == null || m.Arguments.Count == 0)
                throw new NotSupportedException(m.Method.Name + " is not supported!");

            string methodName = m.Method.Name;
            bool isLike = methodName == "StartsWith" || methodName == "Contains" || methodName == "EndsWith";
            // e.g. List<T>.Contains must not be rendered as a LIKE.
            if (isLike && m.Method.DeclaringType != typeof(string))
                throw new NotSupportedException(m.Method.DeclaringType.Name + "." + methodName + " is not supported!");

            string format;
            // Match the exact name; the former ToUpper() comparison could never hit these labels.
            switch (methodName)
            {
                case "StartsWith":
                    format = "({0} like " + ConcatLikePattern("''", "'%'") + ")";
                    break;
                case "Contains":
                    format = "({0} like " + ConcatLikePattern("'%'", "'%'") + ")";
                    break;
                case "EndsWith":
                    format = "({0} like " + ConcatLikePattern("'%'", "''") + ")";
                    break;
                case "Equals":
                    format = "({0} = {1})";
                    break;
                default:
                    throw new NotSupportedException(methodName + " is not supported!");
            }
            this.Visit(m.Object);
            this.Visit(m.Arguments[0]);
            string right = this.conditionParts.Pop();
            string left = this.conditionParts.Pop();
            this.conditionParts.Push(String.Format(format, left, right));
            return m;
        }

        /// <summary>
        /// Builds the LIKE pattern "prefix, {1}, suffix" concatenated in the target dialect.
        /// The value placeholder {1} is filled with the searched value by VisitMethodCall.
        /// </summary>
        /// <param name="prefix">SQL literal placed before the value (e.g. '%').</param>
        /// <param name="suffix">SQL literal placed after the value.</param>
        private string ConcatLikePattern(string prefix, string suffix)
        {
            // In MySQL "||" is logical OR unless PIPES_AS_CONCAT is enabled; CONCAT() works everywhere there.
            if (dataBaseType == DataBaseType.MySql)
                return "CONCAT(" + prefix + ", {1}, " + suffix + ")";
            string connector = GetLikeConnectorWords(dataBaseType);
            return prefix + connector + "{1}" + connector + suffix;
        }

        /// <summary>
        /// Returns the string-concatenation operator used inside LIKE patterns.
        /// </summary>
        /// <param name="databaseType">Target database type.</param>
        /// <returns>"||" for PostgreSQL, Oracle, MySQL and SQLite; "+" otherwise (SQL Server).</returns>
        private string GetLikeConnectorWords(DataBaseType databaseType)
        {
            string result = "+";
            switch (databaseType)
            {
                case DataBaseType.PostGreSql:
                case DataBaseType.Oracle:
                case DataBaseType.MySql:
                case DataBaseType.Sqlite:
                    result = "||";
                    break;
            }
            return result;
        }

        /// <summary>
        /// Gets or sets the target database type.
        /// </summary>
        public DataBaseType DataBaseType
        {
            get { return this.dataBaseType; }
            set { this.dataBaseType = value; }
        }

        /// <summary>
        /// The translated condition with "{n}" placeholders for <see cref="Arguments"/>.
        /// Returns an empty string before <see cref="Build"/> runs or when nothing was produced.
        /// </summary>
        public string Condition
        {
            get
            {
                // Peek rather than Pop, so reading the property twice yields the same text.
                if (this.conditionParts != null && this.conditionParts.Count > 0)
                    return this.conditionParts.Peek();
                return string.Empty;
            }
        }

        /// <summary>
        /// A NEW copy of the collected constant values on every access; index n matches placeholder "{n}".
        /// Empty before <see cref="Build"/> runs.
        /// </summary>
        public object[] Arguments
        {
            get { return this.arguments == null ? new object[0] : this.arguments.ToArray(); }
        }

        /// <summary>
        /// Whether column names are wrapped in double quotes.
        /// </summary>
        public bool WithQuotationMarks
        {
            get { return this.withQuotationMarks; }
            set { this.withQuotationMarks = value; }
        }


    }

 /// <summary>
    /// Pre-computes every sub-tree that does not depend on a lambda parameter (captured locals,
    /// members of closures, method calls on constants, ...). Each such sub-tree is replaced by a
    /// <see cref="ConstantExpression"/> holding the computed value.
    /// </summary>
    class PartialEvaluator : ExpressionVisitor
    {
        private Func<Expression, bool> canEvaluate;
        private HashSet<Expression> evaluable;

        /// <summary>
        /// Creates an evaluator that treats every node except a parameter reference as evaluable.
        /// </summary>
        public PartialEvaluator()
            : this(CanBeEvaluatedLocally)
        {
        }

        /// <summary>
        /// Creates an evaluator with a custom per-node predicate.
        /// </summary>
        /// <param name="fnCanBeEvaluated">Returns false for nodes that must stay in the tree.</param>
        public PartialEvaluator(Func<Expression, bool> fnCanBeEvaluated)
        {
            canEvaluate = fnCanBeEvaluated;
        }

        /// <summary>
        /// Visits <paramref name="exp"/>, replacing every evaluable sub-tree with its value.
        /// </summary>
        /// <param name="exp">Expression tree to reduce.</param>
        /// <returns>The reduced expression tree.</returns>
        public Expression Eval(Expression exp)
        {
            Nominator nominator = new Nominator(canEvaluate);
            evaluable = nominator.Nominate(exp);
            return Visit(exp);
        }

        /// <summary>
        /// Evaluates nominated nodes; descends into all other nodes.
        /// </summary>
        /// <param name="exp">Node being visited.</param>
        /// <returns>The replacement node.</returns>
        protected override Expression Visit(Expression exp)
        {
            if (exp == null)
            {
                return null;
            }
            return evaluable.Contains(exp) ? Evaluate(exp) : base.Visit(exp);
        }

        /// <summary>
        /// Compiles and runs <paramref name="e"/>, wrapping the result in a constant of the same type.
        /// Constants are returned as-is.
        /// </summary>
        /// <param name="e">Parameter-free expression.</param>
        /// <returns>A constant holding the computed value.</returns>
        private Expression Evaluate(Expression e)
        {
            if (e.NodeType != ExpressionType.Constant)
            {
                Delegate compiled = Expression.Lambda(e).Compile();
                object value = compiled.DynamicInvoke(null);
                return Expression.Constant(value, e.Type);
            }
            return e;
        }

        /// <summary>
        /// Default per-node predicate: only parameter references block evaluation.
        /// </summary>
        /// <param name="exp">Node to test.</param>
        /// <returns>False for parameter nodes, true otherwise.</returns>
        private static bool CanBeEvaluatedLocally(Expression exp)
        {
            return exp.NodeType != ExpressionType.Parameter;
        }

        /// <summary>
        /// Bottom-up pass that collects every node whose whole sub-tree can be evaluated.
        /// A node qualifies only when none of its descendants was blocked and the predicate accepts it.
        /// </summary>
        private class Nominator : ExpressionVisitor
        {
            private Func<Expression, bool> predicate;
            private HashSet<Expression> found;
            private bool blocked;

            internal Nominator(Func<Expression, bool> fnCanBeEvaluated)
            {
                predicate = fnCanBeEvaluated;
            }

            internal HashSet<Expression> Nominate(Expression expression)
            {
                found = new HashSet<Expression>();
                Visit(expression);
                return found;
            }

            protected override Expression Visit(Expression expression)
            {
                if (expression == null)
                {
                    return expression;
                }

                // Remember whether an earlier sibling was blocked, then judge this sub-tree on its own.
                bool siblingBlocked = blocked;
                blocked = false;

                base.Visit(expression);

                if (!blocked)
                {
                    if (predicate(expression))
                        found.Add(expression);
                    else
                        blocked = true;
                }

                // A blocked node also blocks every ancestor.
                blocked = blocked || siblingBlocked;
                return expression;
            }
        }

    }




Guess you like

Source: blog.csdn.net/fuweiping/article/details/78920995