苦Mybatis和SpringData久已,我开始自研Java Orm框架

「这是我参与2022首次更文挑战的第1天,活动详情查看:2022首次更文挑战

数据库层提出原因

提出做数据库层的想法主要有以下几个原因:

  1. Java中已有的orm框架书写起来不够优雅,无论是Mybatis还是Spring-data,都不算特别优雅。
  2. 一直以来想要实现一个类似于Laravel的Orm框架那样,靠点点的方法实现基础的SQL构造。
  3. 建了表就自动生成各种默认的ORM层的函数,而不需要手动实现,只对特殊方法需要手动去写。
  4. 数据库层的存在可以反向规范建表规则,比如,数据库层定义了必须要有创建人,创建时间,更新人,更新时间,删除人,删除时间等等,如果在引入的项目中没有使用这些字段建表,那么就会出错。这样做为数据库建表的反向约束,会大大减少数据库层走弯路的成本,也可以减少数据库层花的时间。
  5. 使得后端整体架构松耦合,让业务层的可以专注于解决业务问题,数据库中数据的转换专门交由数据库层的来做。

架构设计

架构原理

一个好的架构是必要的,这里的实现思路主要是参考了SpringData和Laravel的Orm框架。

SpringData能够兼容多种不同的数据库,也就是说,这个架构需要包含良好的可扩展性,这就要求我们需要使用到继承和多态,Java中实现继承和多态的方式分别是抽象类和接口。
同时,SpringData默认实现了调用了基础方法可以直接使用,这就要求我们的数据库层需要提供基础的数据库操作的方法,甚至一些复杂但是具有普遍性的方法也可以作为默认函数提供出去。

Laravel的Orm框架有个显著的特点,支持通过函数的方式构造出Sql,即

$existFolders = TFolder::where(['folder_name' => $request['folder_name'], 'delete_mark' => false])->get();
复制代码

这种方式个人觉得比较优雅,转换到Java中也就是要实现类似下面这种类似的方法。

query.find("*").where(条件).orderby().page().size();
复制代码

通过点点点带方法的方式实现查询。

总结一下,需要实现具备以下特征的数据库层:

  • 支持不同数据库的扩展
  • 提供大量默认的方法
  • 提供基础的方法,扩展查询方法,能通过点点点的方式实现SQL的创建
  • 具备让业务层直接引入即可使用的情景,也具备单独使用的情景


在实现点点点的方式实现SQL的创建的时候,我参考了Laravel的Orm框架的实现原理,对于每一次查询,其都创建了一个Query作为这次查询的主体,同时的查询使用不同的Query实例,可以避免并发问题。

最终架构用思维导图展示如下:
image.png

定义基础字段

首先,需要定义基础查询的抽象类,用于实现部分基础的方法,同时提供子类自己实现的抽象方法。
那么,需要定义哪些方法呢?首先,需要兼容几种基本的SQL,那么就是CRUD相关内容。

我这里定义了如下基础字段

public abstract class BaseQuery<T> {

    protected String primaryKey = "id";
    protected String table = "";
    protected List<String> with = new ArrayList<>();
    protected List<String> withCount = new ArrayList<>();
    protected Integer limit = -1;
    protected Integer offset = -1;
    protected List<String> traitInitializers = new ArrayList<>();
    protected List<String> globalScopes = new ArrayList<>();
    protected List<String> ignoreOnTouch = new ArrayList<>();
    protected List<String> columns = new ArrayList<>();
    protected WhereSyntaxTree wheres = new WhereSyntaxTree();
    protected Map<String, Object> params = new HashMap<>();

    protected Map<String, Object> updateSetMaps = new HashMap<>();

    protected List<String> orders = new ArrayList<>();
}
复制代码

其中,这里使用了泛型编程,在继承该类的时候需要都加入,这里的T就是指的业务表本身,所以到最终底层类的时候,直接写成即可。

这里的where,我定义了一个递归的结构,因为where语句很有可能是一个递归的,即在生成where语句的时候,很有可能括号里面会有括号。其结构如下:

public class WhereSyntaxTree {

    public Boolean isFinal = false;
    List<WhereSyntaxTree> childTree = new ArrayList<>();
    public WhereSyntaxNode whereSyntaxNode;
}
复制代码

可以看出来,这棵where树是一棵多叉树。其中节点的内容为:

public class WhereSyntaxNode {
    private String name;
    private String operate = "=";
    private Object value;
    private String setName;
    private Boolean valueContainBracket;
    private Boolean listValueIsObject; // 列表数据是否是复杂对象,复杂对象需要做特殊处理
}
复制代码

这里的where树,序列化之后,就是 where a = :a and b = :b。而节点对象里面的setName就是冒号后面的变量名,在默认情况下setName等于name,但当setName已经在前面的节点中出现的时候,就需要重新生成该setName。

这里的BaseQuery本质上是针对关系型数据库实现的基础查询。然后考虑实现一个数据库的查询,以PostgreSql为例,实现PostgreSQLBaseQuery,继承自BaseQuery。其字段定义如下:

public class PostgreSQLBaseQuery<T> extends BaseQuery<T> implements LogExtenseInterface, TreeExtenseInterface, LRUCacheExtensionInterface {

    private Log logger = LogFactory.getLog(PostgreSQLBaseQuery.class);

    private Converter<String, String> camelToUnderscoreConverter = CaseFormat.LOWER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE);

    protected Class<T> clazz;

    public T model;
}
复制代码

这里实现了一些插件,是为了避免在BaseQuery里面写过多的代码,让功能解耦,朋友们可以按照自己的理解去实现日志扩展,缓存扩展以及树形查询扩展。

实现基础查询相关函数

首先,每个基础查询的类都需要构造函数,我们来看看相关的实现。
BaseQuery是抽象类,不需要提供构造函数。
PostgreSQLBaseQuery中引入了泛型类T,需要对其进行初始化,所以其构造函数如下:

public PostgreSQLBaseQuery() {
    Class clazz = getClass();
    while (clazz != Object.class) {
        Type t = clazz.getGenericSuperclass();
        if (t instanceof ParameterizedType) {
            Type[] args = ((ParameterizedType) t).getActualTypeArguments();
            if (args[0] instanceof Class) {
                this.clazz = (Class<T>) args[0];
                break;
            }
        }
    }
    try {
        Constructor constructor = this.clazz.getDeclaredConstructor();
        model = (T) constructor.newInstance();
    } catch (NoSuchMethodException e) {
        e.printStackTrace();
    } catch (IllegalAccessException e) {
        e.printStackTrace();
    } catch (InstantiationException e) {
        e.printStackTrace();
    } catch (InvocationTargetException e) {
        e.printStackTrace();
    }
    this.table = camelToUnderscoreConverter.convert(this.clazz.getSimpleName());
    logger.info(this.table);
}
复制代码

这里获取了当前实例化对象的类,目的是为了后续反序列化的时候需要使用。在HashMap转对象的时候,需要提供对象原本的类。

然后,为了实现点点点的效果,我们需要实现一系列基础的方法,这些方法必须返回值也是Query本身,因为只有这样才能持续让其点点点。
以下方法实现在BaseQuery中:

public BaseQuery<T> finds(List<String> names) {
    this.columns.addAll(names);
    return this;
}

public BaseQuery<T> size(Integer size) {
    this.limit = size;
    return this;
}

public BaseQuery<T> page(Integer page) {
    this.offset = page;
    return this;
}

public BaseQuery<T> orderBy(String key, String order) {
    this.orders.add(key + " " + order);
    return this;
}

public BaseQuery<T> find(String name) {
    this.columns.add(name);
    return this;
}

public AndWhereSyntaxTree defaultAndWheres(Map<String, Object> andWheres) {
    return wheres.createAndTree(andWheres);
}


public AndWhereSyntaxTree defaultAndWheresWithOperate(List<Triplet<String, String, Object>> andWheres) {
    return wheres.createAndTreeByOperate(andWheres);
}

public OrWhereSyntaxTree defaultOrWheresWithOperate(List<Triplet<String, String, Object>> orWheres) {
    return wheres.createOrTreeByOperate(orWheres);
}


public OrWhereSyntaxTree defaultOrWheres(Map<String, Object> orWheres) {
    return wheres.createOrTree(orWheres);
}

public BaseQuery<T> set(String name, Object value) {
    this.updateSetMaps.put(name, value);
    return this;
}

public BaseQuery<T> sets(Map<String, Object> sets) {
    this.updateSetMaps.putAll(sets);
    return this;
}
复制代码

这里面的方法提供了基础的查询和更新所需的方法,分析CRUD可以得知,由于现目前系统都要求是逻辑删除,所以删除操作就只需要实现CRU即可,目前已经提供了RU了。只需要在继承类中组合这些方法实现需求即可。

除了提供的方法,还需要定义一些子类必须实现的抽象方法,如下:

public abstract T simpleGet();

public abstract Long count();

public abstract List<Map<String, Object>> listMapGet();

public abstract Long insert(Map<String, Object> values);

public abstract Integer batchInsert(List<Map<String, Object>> listValues);

public abstract Integer update(Map<String, Object> conditions, Map<String, Object> values);

public abstract Integer update(List<Triplet<String, String, Object>> condition, Map<String, Object> values);

public abstract Integer updateById(Object primaryKey, Map<String, Object> value);
复制代码

这些抽象方法可以根据相关的默认方法需求去增多,这是一个迭代的过程,目前就不去深挖了。

最后,在BaseQuery中还可以定义出默认的SQL生成方法,子类可以重载或者重写。

  • 默认生成更新SQL的函数
protected String defaultGenerateUpdateSql() {
    String sql = "";
    if (updateSetMaps.size() == 0) {
        return "";
    }
    sql = "UPDATE " + this.table + " SET ";
    String setSql = "";
    for (Map.Entry<String, Object> set : this.updateSetMaps.entrySet()) {
        if (setSql.equals("")) {
            setSql = set.getKey() + "=:set" + set.getKey();
            this.params.put("set" + set.getKey(), set.getValue());
        } else {
            setSql = setSql + "," + set.getKey() + "=:set" + set.getKey();
            this.params.put("set" + set.getKey(), set.getValue());
        }
    }
    if (StringUtils.hasText(setSql)) {
        sql = sql + " " + setSql;
    }
    String whereSql = this.wheres.getSql(this.params);
    if (StringUtils.hasText(whereSql)) {
        if (whereSql.startsWith("(") && whereSql.endsWith(")")) {
            whereSql = whereSql.substring(1, whereSql.length() - 1);
        }
        sql = sql + " WHERE " + whereSql + " ";
    }
    sql = sql + ";";
    return sql;
}
复制代码

根据当前变量直接生成更新语句。这里就是为什么我前面提到了,对于每一次查询都是一个独立的Query,如果不是的话,这里的变量势必会存在冲突和多线程问题。解决起来就很麻烦。

  • 默认生成查询SQL的函数
protected String defaultGenerateSql() {
    String sql = "";
    for (String column : this.columns) {
        if (sql.equals("")) {
            sql = "SELECT " + column;
        } else {
            sql = sql + "," + column;
        }
    }
    if (!StringUtils.hasText(sql)) {
        return null;
    } else {
        sql = sql + " ";
    }
    sql = sql + "FROM " + table + " ";
    String whereSql = this.wheres.getSql(this.params);
    if (StringUtils.hasText(whereSql)) {
        if (whereSql.startsWith("(") && whereSql.endsWith(")")) {
            whereSql = whereSql.substring(1, whereSql.length() - 1);
        }
        sql = sql + "WHERE " + whereSql + " ";
    }
    // TODO: 2021/8/17 order by
    if (this.orders.size() > 0) {
        String orderSql = "";
        for (String order : this.orders) {
            if (orderSql.equals("")) {
                orderSql = "order by " + order;
            } else {
                orderSql = orderSql + "," + order;
            }
        }
        sql = sql + orderSql + " ";
    }
    // TODO: 2021/8/17 offset
    if (this.offset != -1) {
        if (this.limit != -1) {
            sql = sql + " offset " + (this.offset - 1) * this.limit + " ";
        } else {
            sql = sql + " offset " + this.offset + " ";
        }
    }
    if (this.limit != -1) {
        sql = sql + " limit " + this.limit + " ";
    }
    sql = sql + ";";
    return sql;
}
复制代码

好了,我们的基础查询所具备的方法差不多就是这些了。

where多叉树如何生成SQL

不知道你是否还记得,我们在定义基础查询类的where的时候,是一个多叉树的类型。那么为啥是多叉树呢?我用一张图来解释一下吧:
image.png
那么,我们生成全量的where语句的方法就很明显了,只需要按照多叉树的深度优先遍历方式打印出每个节点即可。代码如下:

public String getSql(Map<String, Object> params) {
    if (isFinal) {
        if (params.containsKey(whereSyntaxNode.getSetName())) {
            Random random = new Random();
            whereSyntaxNode.setSetName(MD5Utils.compMd5(whereSyntaxNode.getSetName() + LocalDateTime.now().toString() + random.ints().toString()));
        }
        params.put(whereSyntaxNode.getSetName(), whereSyntaxNode.getValue());
        if (whereSyntaxNode.getValueContainBracket()) {
            return whereSyntaxNode.getName() + " " + whereSyntaxNode.getOperate() + " (:" + whereSyntaxNode.getSetName() + ")";
        } else {
            return whereSyntaxNode.getName() + " " + whereSyntaxNode.getOperate() + " :" + whereSyntaxNode.getSetName();
        }
    } else {
        String sunSql = "";
        for (WhereSyntaxTree whereSyntaxTree : childTree) {
            if (whereSyntaxTree instanceof AndWhereSyntaxTree) {
                if (sunSql.equals("")) {
                    sunSql = whereSyntaxTree.getSql(params);
                } else {
                    sunSql = sunSql + " AND " + whereSyntaxTree.getSql(params);
                }
            } else if (whereSyntaxTree instanceof OrWhereSyntaxTree) {
                if (sunSql.equals("")) {
                    sunSql = whereSyntaxTree.getSql(params);
                } else {
                    sunSql = sunSql + " OR " + whereSyntaxTree.getSql(params);
                }
            }
        }
        if (StringUtils.hasText(sunSql)) {
            return "(" + sunSql + ")";
        } else {
            return "";
        }
    }
}
复制代码

PostgreSql子查询类相关函数

首先,是需要实现基础查询留下的抽象方法

@Override
public T simpleGet() {
    String sql = defaultGenerateSql();
    logger.info(sql);
    logger.info(this.params);
    try {
        T xx = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).queryForObject(sql, this.params, new BeanPropertyRowMapper<>(this.clazz));
        return xx;
    } catch (Exception e) {
    }
    return null;
}

@Override
public Long count() {
    String findStr = "count(" + primaryKey + ")";
    this.find(findStr);
    String sql = defaultGenerateSql();
    logger.info(sql);
    logger.info(this.params);
    Long ans = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).queryForObject(sql, this.params, Long.class);
    return ans;
}

@Override
public List<Map<String, Object>> listMapGet() {
    String sql = defaultGenerateSql();
    logger.info(sql);
    logger.info(this.params);
    return SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).queryForList(sql, this.params);
}

@Override
public Long insert(Map<String, Object> values) {
    if (Objects.isNull(values)) values = new HashMap<>();
    values = removeNull(values);
    String names = "";
    String nameParams = "";
    for (Map.Entry<String, Object> tmp : values.entrySet()) {
        if (names.equals("")) {
            names = names + tmp.getKey();
            nameParams = nameParams + ":" + tmp.getKey();
        } else {
            names = names + "," + tmp.getKey();
            nameParams = nameParams + ",:" + tmp.getKey();
        }
    }
    this.params.putAll(values);
    params = convertParams(params);
    String sql = "INSERT INTO " + this.table + "(" + names + ") VALUES (" + nameParams + ") RETURNING " + primaryKey + ";";
    logger.info(sql);
    logger.info(this.params);
    Long id = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).queryForObject(sql, this.params, Long.class);
    return id;
}

@Override
public Integer batchInsert(List<Map<String, Object>> listValues) {
    if (Objects.isNull(listValues)) listValues = new ArrayList<>();
    int cnt = 0;
    String insertNames = "";
    List<String> insertNameParams = new ArrayList<>();
    int n = listValues.size();
    for (Map<String, Object> values : listValues) {
        values = removeNull(values);
        values = convertParams(values);
        String names = "";
        String nameParams = "";
        for (Map.Entry<String, Object> tmp : values.entrySet()) {
            if (names.equals("")) {
                names = names + tmp.getKey();
                nameParams = nameParams + ":" + tmp.getKey() + cnt;
            } else {
                names = names + "," + tmp.getKey();
                nameParams = nameParams + ",:" + tmp.getKey() + cnt;
            }
            this.params.put(tmp.getKey() + "" + cnt, tmp.getValue());
        }
        insertNames = names;
        nameParams = "(" + nameParams + ")";
        insertNameParams.add(nameParams);
        cnt++;
    }
    String sql = "INSERT INTO " + this.table + "(" + insertNames + ") VALUES " + ArrayStrUtil.slist2Str(insertNameParams, ",") + ";";
    logger.info(sql);
    logger.info(this.params);
    int x = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).update(sql, this.params);
    return x;
}


@Override
public Integer update(Map<String, Object> conditions, Map<String, Object> values) {
    if (Objects.isNull(conditions)) conditions = new HashMap<>();
    if (Objects.isNull(values)) values = new HashMap<>();
    conditions = removeNull(conditions);
    values = removeNull(values);
    this.updateSetMaps.putAll(values);
    WhereSyntaxTree whereSyntaxTree = defaultAndWheres(conditions);
    this.where(whereSyntaxTree);
    String sql = defaultGenerateUpdateSql();
    params = convertParams(params);
    logger.info(sql);
    logger.info(params);
    Integer influenceNumber = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).update(sql, this.params);
    return influenceNumber;
}

@Override
public Integer update(List<Triplet<String, String, Object>> condition, Map<String, Object> values) {
    if (Objects.isNull(condition)) condition = new ArrayList<>();
    if (Objects.isNull(values)) values = new HashMap<>();
    values = removeNull(values);
    this.updateSetMaps.putAll(values);
    WhereSyntaxTree whereSyntaxTree = defaultAndWheresWithOperate(condition);
    this.where(whereSyntaxTree);
    String sql = defaultGenerateUpdateSql();
    params = convertParams(params);
    logger.info(sql);
    logger.info(params);
    Integer influenceNumber = SpringContextUtil.getBean(NamedParameterJdbcTemplate.class).update(sql, this.params);
    return influenceNumber;
}


private Map<String, Object> convertParams(Map<String, Object> params) {
    Map<String, Object> newParams = new HashMap<>();
    for (Map.Entry<String, Object> param : params.entrySet()) {
        try {
            newParams.put(param.getKey(), DatetimeUtil.getLocalDatetimeByStr((String) param.getValue()));
        } catch (Exception e) {
            newParams.put(param.getKey(), param.getValue());
        }
    }
    return newParams;
}

@Override
public Integer updateById(Object primaryKey, Map<String, Object> values) {
    Map<String, Object> conditions = new HashMap<>();
    conditions.put(this.primaryKey, primaryKey);
    return this.update(conditions, values);
}

@Override
public BaseQuery<T> where(WhereSyntaxTree whereSyntaxTree) {
    this.wheres = whereSyntaxTree;
    return this;
}
复制代码

然后再定义一些复杂查询,留给业务查询来调用

/**
     * select * from a where id = 2 ......
     *
     * @param id
     * @return
     */
public T findModelById(Object id) {
    return this.find("*").findById(id).simpleGet();
}

/**
     * select aa,aaa,aaaa from a where id = 2 ....
     *
     * @param id
     * @param fields
     * @return
     */
protected T findModelById(Object id, List<String> fields) {
    return this.finds(fields).findById(id).simpleGet();
}

/**
     * select * from a where (x=1 and y=2) and delete_mark = false ......
     *
     * @param andCondition
     * @return
     */
public T findModelBySimpleAnd(Map<String, Object> andCondition) {
    if (Objects.isNull(andCondition)) andCondition = new HashMap<>();
    andCondition = removeNull(andCondition);
    andCondition.put("deleted_mark", false);
    AndWhereSyntaxTree andWhereSyntaxTree = this.defaultAndWheres(andCondition);
    return this.find("*").where(andWhereSyntaxTree).orderBy(primaryKey, "desc").size(1).simpleGet();
}

/**
     * select * from a where (x=1 and y=2) and delete_mark = true ......
     *
     * @param andCondition
     * @return
     */
public T findModelBySimpleAndDeletedMarkTrue(Map<String, Object> andCondition) {
    if (Objects.isNull(andCondition)) andCondition = new HashMap<>();
    andCondition = removeNull(andCondition);
    andCondition.put("deleted_mark", true);
    AndWhereSyntaxTree andWhereSyntaxTree = this.defaultAndWheres(andCondition);
    return this.find("*").where(andWhereSyntaxTree).orderBy(primaryKey, "desc").size(1).simpleGet();
}

/**
     * select * from a where (x = 1 or y = 2) and delete_mark = false .....
     *
     * @param orCondition
     * @return
     */
public T findModelBySimpleOr(Map<String, Object> orCondition) {
    if (Objects.isNull(orCondition)) orCondition = new HashMap<>();
    orCondition = removeNull(orCondition);
    OrWhereSyntaxTree orWhereSyntaxTree = this.defaultOrWheres(orCondition);
    Map<String, Object> andWhereCondition = new HashMap<>();
    andWhereCondition.put("deleted_mark", false);
    andWhereCondition.put(MD5Utils.compMd5(orWhereSyntaxTree.toString() + LocalDateTime.now().toString()), orWhereSyntaxTree);
    AndWhereSyntaxTree andWhereSyntaxTree = this.defaultAndWheres(andWhereCondition);
    return this.find("*").where(andWhereSyntaxTree).size(1).simpleGet();
}
复制代码

你还可以定义更多的方法供后面的子类来使用。

猜你喜欢

转载自juejin.im/post/7055877832261517343