前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Java | 实现一个ORM比你想象的还要简单

Java | 实现一个ORM比你想象的还要简单

作者头像
双鬼带单
发布2021-03-18 21:21:25
1.4K0
发布2021-03-18 21:21:25
举报
文章被收录于专栏:CodingToDieCodingToDie

实现一个 ORM 到底多简单


Table of Contents

原理ORM 实现1. 通过注解来将 Java Bean 和数据库字段关联2. 反射工具类3. 简单的 model 示例4. 注解解析5. 数据库操作6. 结合反射实现查询操作使用动态代理实现 @Query @Select 类似功能1. 动态代理2. 注解3. 表设计4. model5. repository7. 大体流程8. 代理使用9. 将生成代理放入 Spring IOC 容器中10. invoke方法处理


在这篇文章中主要用到了注解、反射、动态代理、正则、Jexl3 表达式来实现一个 ORM, 在大多数框架中,都会用到这些东西,在这里只是简单的使用一下,而不是原理分析

原理

在使用的 ORM 框架中,我可以像操作对象一样操作数据的存储,这是怎么实现的,我们知道数据库是认识 SQL 语句的,但并不认识java bean 呀!同时我们在使用ORM时,需要根据ORM框架的规定定义我们的bean,这是为什么?

这是因为 ORM 为我们提供了将对象操作转化为对应的 SQL语句,例如 save(bean), 这时就需要转化成一个 insert 语句,update(bean) 这时就需要转成成对应的 update 语句

通常 insert 语句格式为

代码语言:javascript
复制
1insert into 表名 (列名) values( 值)

update 语句为

代码语言:javascript
复制
1update 表名 set 列名 = 值 where id = 值

上面的格式可以看出,如果我们能从对象中得出 表名 列名 ,我们也可以写一个简单的ORM框架

ORM 实现

1. 通过注解来将 Java Bean 和数据库字段关联

上篇文章中提到了一下注解,以及自定义注解和解析注解的方法,通过使用注解,我们可以完成一个简单的ORM

要想实现对数据库的操作,我们必须知道数据表名以及表中的字段名称以及类型,正如hibernate 使用注解标识 model 与数据库的映射关系一样,这里我使用了Java Persistence API

注解说明

注解

作用

使用说明

@Entity

标记这是一个实体

标注在类上,标明该类可以被 ORM 处理

@Table

标记实体对应的表

标注在类上,标明该类对应的数据库标明

@Id

标记该字段是id

标注在字段上,标明该字段为 id

@Column

标记该字段对应的列信息

标记在字段上,标明对应的列信息,主要对应列名

字段属性表 用来存储对象中字段与数据表列的对应关系

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.core;
 2
 3import lombok.Data;
 4
 5import javax.persistence.GenerationType;
 6
 7/**
 8 * The type Db column info.
 9 *
10 * @author 张瑀楠 zyndev@gmail.com
11 * @version 0.0.1
12 */
13@Data
14class DBColumnInfo {
15
16    /**
17     * (Optional) The primary key generation strategy
18     * that the persistence provider must use to
19     * generate the annotated entity primary key.
20     */
21    private GenerationType strategy = GenerationType.AUTO;
22
23    private String fieldName;
24
25    /**
26     * The name of the column
27     */
28    private String columnName;
29
30    /**
31     * Whether the column is a unique key.
32     */
33    private boolean unique;
34
35    /**
36     * Whether the database column is nullable.
37     */
38    private boolean nullable = true;
39
40    /**
41     * Whether the column is included in SQL INSERT
42     */
43    private boolean insertAble = true;
44
45    /**
46     * Whether the column is included in SQL UPDATE
47     */
48    private boolean updatable = true;
49
50    /**
51     * The SQL fragment that is used when
52     * generating the DDL for the column.
53     */
54    private String columnDefinition;
55
56    /**
57     * The name of the table that contains the column.
58     * If absent the column is assumed to be in the primary table.
59     */
60    private String table;
61
62    /**
63     * (Optional) The column length. (Applies only if a
64     * string-valued column is used.)
65     */
66    private int length =  255;
67
68    private boolean id = false;
69
70}

2. 反射工具类

提供一些常用的反射操作

通过反射我们可以动态的得到一个类所有的成员变量的信息,同时为这些变量取值或者赋值

代码语言:javascript
复制
  1package com.zyndev.tool.fastsql.util;
  2
  3
  4import java.lang.reflect.Field;
  5import java.lang.reflect.Method;
  6import java.util.ArrayList;
  7import java.util.HashMap;
  8import java.util.List;
  9import java.util.Map;
 10
 11
 12/**
 13 * The type Bean reflection util.
 14 *
 15 * @author yunan.zhang zyndev@gmail.com
 16 * @version 0.0.3
 17 * @date 2017年12月26日 16点36分
 18 */
 19public class BeanReflectionUtil {
 20
 21    public static Object getPrivatePropertyValue(Object obj,String propertyName)throws Exception{
 22        Class cls=obj.getClass();
 23        Field field=cls.getDeclaredField(propertyName);
 24        field.setAccessible(true);
 25        Object retvalue=field.get(obj);
 26        return retvalue;
 27    }
 28
 29    /**
 30     * Gets static field value.
 31     */
 32    public static Object getStaticFieldValue(String className, String fieldName) throws Exception {
 33        Class cls = Class.forName(className);
 34        Field field = cls.getField(fieldName);
 35        return field.get(cls);
 36    }
 37
 38    /**
 39     * Gets field value.
 40     */
 41    public static Object getFieldValue(Object obj, String fieldName) throws Exception {
 42        Class cls = obj.getClass();
 43        Field field = cls.getDeclaredField(fieldName);
 44        field.setAccessible(true);
 45        return field.get(obj);
 46    }
 47
 48    /**
 49     * Invoke method object.
 50     */
 51    public Object invokeMethod(Object owner, String methodName, Object[] args) throws Exception {
 52        Class cls = owner.getClass();
 53        Class[] argclass = new Class[args.length];
 54        for (int i = 0, j = argclass.length; i < j; i++) {
 55            argclass[i] = args[i].getClass();
 56        }
 57        @SuppressWarnings("unchecked")
 58        Method method = cls.getMethod(methodName, argclass);
 59        return method.invoke(owner, args);
 60    }
 61
 62    /**
 63     * Invoke static method object.
 64     */
 65    public Object invokeStaticMethod(String className, String methodName, Object[] args) throws Exception {
 66        Class cls = Class.forName(className);
 67        Class[] argClass = new Class[args.length];
 68        for (int i = 0, j = argClass.length; i < j; i++) {
 69            argClass[i] = args[i].getClass();
 70        }
 71        @SuppressWarnings("unchecked")
 72        Method method = cls.getMethod(methodName, argClass);
 73        return method.invoke(null, args);
 74    }
 75
 76    /**
 77     * New instance object.
 78     */
 79    public static Object newInstance(String className) throws Exception {
 80        Class clazz = Class.forName(className);
 81        @SuppressWarnings("unchecked")
 82        java.lang.reflect.Constructor cons = clazz.getConstructor();
 83        return cons.newInstance();
 84    }
 85
 86    /**
 87     * New instance object.
 88     */
 89    public static Object newInstance(Class clazz) throws Exception {
 90        @SuppressWarnings("unchecked")
 91        java.lang.reflect.Constructor cons = clazz.getConstructor();
 92        return cons.newInstance();
 93    }
 94
 95    /**
 96     * Get bean declared fields field [ ].
 97     */
 98    public static Field[] getBeanDeclaredFields(String className) throws ClassNotFoundException {
 99        Class clazz = Class.forName(className);
100        return clazz.getDeclaredFields();
101    }
102
103    /**
104     * Get bean declared methods method [ ].
105     */
106    public static Method[] getBeanDeclaredMethods(String className) throws ClassNotFoundException {
107        Class clazz = Class.forName(className);
108        return clazz.getMethods();
109    }
110
111    /**
112     * Copy fields.
113     */
114    public static void copyFields(Object source, Object target) {
115        try {
116            List<Field> list = new ArrayList<>();
117            Field[] sourceFields = getBeanDeclaredFields(source.getClass().getName());
118            Field[] targetFields = getBeanDeclaredFields(target.getClass().getName());
119            Map<String, Field> map = new HashMap<>(targetFields.length);
120            for (Field field : targetFields) {
121                map.put(field.getName(), field);
122            }
123            for (Field field : sourceFields) {
124                if (map.get(field.getName()) != null) {
125                    list.add(field);
126                }
127            }
128            for (Field field : list) {
129                Field tg = map.get(field.getName());
130                tg.setAccessible(true);
131                tg.set(target, getFieldValue(source, field.getName()));
132            }
133        } catch (Exception e) {
134            e.printStackTrace();
135        }
136    }
137}

3. 简单的 model 示例

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql;
 2
 3import javax.persistence.*;
 4
 5/**
 6 * @author 张瑀楠 zyndev@gmail.com
 7 * @date   2017/11/30 下午11:21
 8 */
 9@Data
10@Entity
11@Table(name = "tb_student")
12public class Student {
13
14    @Id
15    @Column
16    private Integer id;
17
18    @Column
19    private String name;
20
21    @Column(updatable = true, insertable = true, nullable = false)
22    private Integer age;
23
24}

4. 注解解析

将对象的上的注解进行解析,得到对应关系

代码语言:javascript
复制
  1package com.zyndev.tool.fastsql.core;
  2
  3
  4import com.zyndev.tool.fastsql.util.StringUtil;
  5
  6import javax.persistence.Column;
  7import javax.persistence.Id;
  8import javax.persistence.Table;
  9import java.lang.reflect.Field;
 10import java.util.ArrayList;
 11import java.util.HashMap;
 12import java.util.List;
 13import java.util.Map;
 14
 15
 16/**
 17 * The type Annotation parser.
 18 * <p>
 19 * 注解解析工具
 20 *
 21 * @author yunan.zhang zyndev@gmail.com
 22 * @version 0.0.3
 23 * @since 2017-12-26 16:59:07
 24 */
 25public class AnnotationParser {
 26
 27    /**
 28     * 存储类名和数据库表名的关系
 29     * 使用三个cache 主要为了减少反射的次数,提高效率
 30     */
 31    private final static Map<String, String> tableNameCache = new HashMap<>(30);
 32
 33    /**
 34     * 存储类名和数据库列的关联关系
 35     */
 36    private final static Map<String, String> tableAllColumnNameCache = new HashMap<>(30);
 37
 38    /**
 39     * 存储类名和对应的数据库全列名的关系
 40     */
 41    private final static Map<String, List<DBColumnInfo>> tableAllDBColumnCache = new HashMap<>(30);
 42
 43    /**
 44     * Gets table name.
 45     * 获得表名
 46     * 判断是否有@Table注解,如果有得到注解的name,如果name为空,则使用类名做为表名
 47     * 如果没有@Table返回null
 48     *
 49     * @param <E>    the type parameter
 50     * @param entity the entity
 51     * @return the table name
 52     */
 53    public static <E> String getTableName(E entity) {
 54        String tableName = tableNameCache.get(entity.getClass().getName());
 55        if (tableName == null) {
 56            Table table = entity.getClass().getAnnotation(Table.class);
 57            if (table != null && StringUtil.isNotBlank(table.name())) {
 58                tableName = table.name();
 59            } else {
 60                tableName = entity.getClass().getSimpleName();
 61            }
 62            tableNameCache.put(entity.getClass().getName(), tableName);
 63        }
 64        return tableName;
 65    }
 66
 67    /**
 68     * Gets all db column info.
 69     */
 70    public static <E> List<DBColumnInfo> getAllDBColumnInfo(E entity) {
 71        List<DBColumnInfo> dbColumnInfoList = tableAllDBColumnCache.get(entity.getClass().getName());
 72        if (dbColumnInfoList == null) {
 73            dbColumnInfoList = new ArrayList<>();
 74            DBColumnInfo dbColumnInfo;
 75            Field[] fields = entity.getClass().getDeclaredFields();
 76            for (Field field : fields) {
 77                Column column = field.getAnnotation(Column.class);
 78                if (column != null) {
 79                    dbColumnInfo = new DBColumnInfo();
 80                    if (StringUtil.isBlank(column.name())) {
 81                        dbColumnInfo.setColumnName(field.getName());
 82                    } else {
 83                        dbColumnInfo.setColumnName(column.name());
 84                    }
 85                    if (null != field.getAnnotation(Id.class)) {
 86                        dbColumnInfo.setId(true);
 87                    }
 88                    dbColumnInfo.setFieldName(field.getName());
 89                    dbColumnInfoList.add(dbColumnInfo);
 90                }
 91            }
 92            tableAllDBColumnCache.put(entity.getClass().getName(), dbColumnInfoList);
 93        }
 94        return dbColumnInfoList;
 95    }
 96
 97    /**
 98     * 返回表字段的所有字段  column1,column2,column3
 99     *
100     * @param <E>    the type parameter
101     * @param entity the entity
102     * @return string
103     */
104    public static <E> String getTableAllColumn(E entity) {
105
106        String allColumn = tableAllColumnNameCache.get(entity.getClass().getName());
107        if (allColumn == null) {
108            List<DBColumnInfo> dbColumnInfoList = getAllDBColumnInfo(entity);
109            StringBuilder allColumnsInfo = new StringBuilder();
110            int i = 1;
111            for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
112                allColumnsInfo.append(dbColumnInfo.getColumnName());
113                if (i != dbColumnInfoList.size()) {
114                    allColumnsInfo.append(",");
115                }
116                i++;
117            }
118            allColumn = allColumnsInfo.toString();
119            tableAllColumnNameCache.put(entity.getClass().getName(), allColumn);
120        }
121        return allColumn;
122
123    }
124}

5. 数据库操作

数据库交互使用spring 提供的 JdbcTemplate,这里没有在自己写一套 DBUtil

6. 结合反射实现查询操作

保存一个entity

保存操作相对简单,这里主要是将 entity 转换为 insert 语句

代码语言:javascript
复制
 1/**
 2 * Save int.
 3 * @param entity the entity
 4 * @return the int
 5 */
 6@Override
 7public int save(Object entity) {
 8    try {
 9        String tableName = AnnotationParser.getTableName(entity);
10        StringBuilder property = new StringBuilder();
11        StringBuilder value = new StringBuilder();
12        List<Object> propertyValue = new ArrayList<>();
13        List<DBColumnInfo> dbColumnInfoList = AnnotationParser.getAllDBColumnInfo(entity);
14
15        for (DBColumnInfo dbColumnInfo : dbColumnInfoList) {
16            if (dbColumnInfo.isId() || !dbColumnInfo.isInsertAble()) {
17                continue;
18            }
19            // 不为null
20            Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
21            if (o != null) {
22                property.append(",").append(dbColumnInfo.getColumnName());
23                value.append(",").append("?");
24                propertyValue.add(o);
25            }
26        }
27        String sql = "insert into " + tableName + "(" + property.toString().substring(1) + ") values(" + value.toString().substring(1) + ")";
28        return this.getJdbcTemplate().update(sql, propertyValue.toArray());
29    } catch (Exception e) {
30        e.printStackTrace();
31    }
32    return 0;
33}

更新操作

更新操作相对于 保存来说,多了一步确定where 语句操作

代码语言:javascript
复制
 1/**
 2 * Update int.
 3 *
 4 * @param entity     the entity
 5 * @param ignoreNull the ignore null
 6 * @param columns    the columns
 7 * @return the int
 8 */
 9@Override
10public int update(Object entity, boolean ignoreNull, String... columns) {
11    try {
12        String tableName = AnnotationParser.getTableName(entity);
13        StringBuilder property = new StringBuilder();
14        StringBuilder where = new StringBuilder();
15        List<Object> propertyValue = new ArrayList<>();
16        List<Object> wherePropertyValue = new ArrayList<>();
17        List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
18        for (DBColumnInfo dbColumnInfo : dbColumnInfos) {
19
20            Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
21            if (dbColumnInfo.isId()) {
22                where.append(" and ").append(dbColumnInfo.getColumnName()).append(" = ? ");
23                wherePropertyValue.add(o);
24            } else if (ignoreNull || o != null) {
25                property.append(",").append(dbColumnInfo.getColumnName()).append("=?");
26                propertyValue.add(o);
27            }
28        }
29
30        if (wherePropertyValue.isEmpty()) {
31            throw new IllegalArgumentException("更新表 [" + tableName + "] 无法找到id, 请求数据:" + entity);
32        }
33
34        String sql = "update " + tableName + " set " + property.toString().substring(1) + " where " + where.toString().substring(5);
35        propertyValue.addAll(wherePropertyValue);
36        return this.getJdbcTemplate().update(sql, propertyValue.toArray());
37    } catch (Exception e) {
38        e.printStackTrace();
39    }
40    return 0;
41}

删除操作

相对 update 简单一点

代码语言:javascript
复制
 1/**
 2 * Delete int.
 3 * <p>根据id 删除对应的数据</p>
 4 *
 5 * @param entity the entity
 6 * @return the int
 7 */
 8@Override
 9public int delete(Object entity) {
10    try {
11        String tableName = AnnotationParser.getTableName(entity);
12        StringBuilder where = new StringBuilder(" 1=1 ");
13        List<Object> whereValue = new ArrayList<>(5);
14        List<DBColumnInfo> dbColumnInfos = AnnotationParser.getAllDBColumnInfo(entity);
15        for (DBColumnInfo dbColumnInfo : dbColumnInfos) {
16            if (dbColumnInfo.isId()) {
17                Object o = BeanReflectionUtil.getFieldValue(entity, dbColumnInfo.getFieldName());
18                if (null != o) {
19                    whereValue.add(o);
20                }
21                where.append(" and `").append(dbColumnInfo.getColumnName()).append("` = ? ");
22            }
23        }
24
25        if (whereValue.size() == 0) {
26            throw new IllegalStateException("delete " + tableName + " id 无对应值,不能删除");
27        }
28        String sql = "delete from  " + tableName + " where " + where.toString();
29        return this.getJdbcTemplate().update(sql, whereValue);
30    } catch (Exception e) {
31        e.printStackTrace();
32    }
33    return 0;
34}

通过上面的示例,就可以简单的实现一个 ORM, 为了更好的使用,我们还需要提供自己写 SQL 的方案

使用动态代理实现 @Query @Select 类似功能

1. 动态代理

这里直接使用基于 JDK 动态代理实现

2. 注解

Java Persistence API 中没有我们需要的 @Query 和 @Param 这里我们自定义一下这两个注解,同时为了让 Query 支持返回主键,在定义一个 ReturnGeneratedKey 注解

Query.java

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.annotation;
 2
 3import java.lang.annotation.ElementType;
 4import java.lang.annotation.Retention;
 5import java.lang.annotation.RetentionPolicy;
 6import java.lang.annotation.Target;
 7
 8/**
 9 * 查询操作 sql
10 *
11 * @author 张瑀楠 zyndev@gmail.com
12 * @version 0.0.1
13 * @since  2017/12/22 17:26
14 */
15@Target({ElementType.METHOD})
16@Retention(RetentionPolicy.RUNTIME)
17public @interface Query {
18
19    /**
20     * sql 语句
21     */
22    String value();
23}

Param.java

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.annotation;
 2
 3import java.lang.annotation.ElementType;
 4import java.lang.annotation.Retention;
 5import java.lang.annotation.RetentionPolicy;
 6import java.lang.annotation.Target;
 7
 8/**
 9 *
10 * @author 张瑀楠 zyndev@gmail.com
11 * @version 0.0.1
12 * @since  2017/12/22 17:29
13 */
14@Target( { ElementType.PARAMETER })
15@Retention(RetentionPolicy.RUNTIME)
16public @interface Param {
17
18    /**
19     * 指出这个值是 SQL 语句中哪个参数的值,使用命名参数
20     */
21    String value();
22}

ReturnGeneratedKey.java

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.annotation;
 2
 3
 4import java.lang.annotation.*;
 5
 6
 7/**
 8 * 返回主键
 9 *
10 * @author 张瑀楠 zyndev@gmail.com
11 * @version 0.0.3
12 *
13 */
14@Target({ElementType.METHOD})
15@Retention(RetentionPolicy.RUNTIME)
16@Documented
17public @interface ReturnGeneratedKey {
18}

3. 表设计

代码语言:javascript
复制
 1CREATE TABLE `tb_user` (
 2  `id` int(11) NOT NULL AUTO_INCREMENT,
 3  `uid` varchar(40) DEFAULT NULL,
 4  `account_name` varchar(40) DEFAULT NULL,
 5  `nick_name` varchar(23) DEFAULT NULL,
 6  `password` varchar(30) DEFAULT NULL,
 7  `phone` varchar(16) DEFAULT NULL,
 8  `register_time` timestamp NULL DEFAULT NULL,
 9  `update_time` timestamp NULL DEFAULT NULL,
10  PRIMARY KEY (`id`)
11) ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8

4. model

这里使用一个 User.java 作为例子:

User.java

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.repository;
 2
 3import lombok.*;
 4
 5import java.io.Serializable;
 6import java.util.Date;
 7
 8/**
 9 * @author 张瑀楠 zyndev@gmail.com
10 * @version 1.0
11 * @date 2017-12-27 15:04:13
12 */
13@Data
14public class User implements Serializable {
15
16    private static final long serialVersionUID = 1L;
17
18    private Integer id;
19
20    private String uid;
21
22    private String accountName;
23
24    private String nickName;
25
26    private String password;
27
28    private String phone;
29
30    private Date registerTime;
31
32    private Date updateTime;
33
34}

5. repository

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.repository;
 2
 3import com.zyndev.tool.fastsql.annotation.Param;
 4import com.zyndev.tool.fastsql.annotation.Query;
 5import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
 6import org.springframework.stereotype.Repository;
 7
 8import java.util.List;
 9import java.util.Map;
10
11/**
12 * 这里应该有描述
13 *
14 * @version 1.0
15 * @author 张瑀楠 zyndev@gmail.com
16 * @date 2017 /12/22 18:13
17 */
18@Repository
19public interface UserRepository {
20
21    @Query("select count(*) from tb_user")
22    public Integer getCount();
23
24    @Query("delete from tb_user where id = ?1")
25    public Boolean deleteById(int id);
26
27    @Query("select count(*) from tb_user where password = ?1 ")
28    public int getCountByPassword(@Param("password") String password);
29
30    @Query("select uid from tb_user where password = ?1 ")
31    public String getUidByPassword(@Param("password") String password);
32
33    @Query("select * from tb_user where id = :id ")
34    public User getUserById(@Param("id") Integer id);
35
36    @Query("select * " +
37            " from tb_user " +
38            " where account_name = :accountName ")
39    public List<User> getUserByAccountName(@Param("accountName") String accountName);
40
41    @Query("insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time) " +
42            "values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
43    public int saveUser(@Param("id") Integer id, @Param("user") User user);
44
45    @ReturnGeneratedKey
46    @Query("insert into tb_user(account_name, password, uid, nick_name, register_time, update_time) " +
47            "values(:user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )")
48    public int saveUser(@Param("user") User user);
49}

7. 大体流程

在上面的我们已经完成一些准备工作,包括:

  1. 注解的定义
  2. 表的设计
  3. model 的设计
  4. Repository 的设计

接下来,我们看看如何将这些整合在一起

大致流程:

  1. 为 Repository 生成代理
  2. 将生成代理放入 Spring IOC 容器中
  3. 当代理的方法被调用时,得到方法的 @Query , @Param,@ReturnGeneratedKey 注解,并取得方法的返回值
  4. 重写 Query的sql,并执行,根据方法的返回类型,封装SQL返回结果集

8. 代理使用

FacadeProxy.java

为 Repository 生成代理,当代理方法执行时,回调 invoke 方法,invoke 中逻辑写到StatementParser.java中,防止类功能过大

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.core;
 2
 3import org.apache.commons.logging.Log;
 4import org.apache.commons.logging.LogFactory;
 5
 6import java.lang.reflect.InvocationHandler;
 7import java.lang.reflect.Method;
 8import java.lang.reflect.Proxy;
 9
10/**
11 * 生成代理
12 *
13 * @author 张瑀楠 zyndev@gmail.com
14 * @version 0.0.1
15 * @since  2017 /12/23 上午12:40
16 */
17@SuppressWarnings("unchecked")
18public class FacadeProxy implements InvocationHandler {
19
20    private final static Log logger = LogFactory.getLog(FacadeProxy.class);
21
22    @Override
23    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
24        return StatementParser.invoke(proxy, method, args);
25    }
26
27    /**
28     * New mapper proxy t.
29     */
30    protected static <T> T newMapperProxy(Class<T> mapperInterface) {
31        logger.info(" 生成代理:" + mapperInterface.getName());
32        ClassLoader classLoader = mapperInterface.getClassLoader();
33        Class<?>[] interfaces = new Class[]{mapperInterface};
34        FacadeProxy proxy = new FacadeProxy();
35        return (T) Proxy.newProxyInstance(classLoader, interfaces, proxy);
36    }
37}

9. 将生成代理放入 Spring IOC 容器中

在这里使用了 BeanFactoryAware,关于这部分内容会单独写一篇,这里不在详细说明

代码语言:javascript
复制
 1package com.zyndev.tool.fastsql.core;
 2
 3import com.zyndev.tool.fastsql.util.ClassScanner;
 4import com.zyndev.tool.fastsql.util.StringUtil;
 5import org.springframework.beans.BeansException;
 6import org.springframework.beans.factory.BeanFactory;
 7import org.springframework.beans.factory.BeanFactoryAware;
 8import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 9import org.springframework.beans.factory.support.BeanDefinitionRegistry;
10import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
11import org.springframework.core.type.AnnotationMetadata;
12import org.springframework.stereotype.Repository;
13
14import java.io.IOException;
15import java.util.Set;
16
17/**
18 * The type Fast sql repository registrar.
19 *
20 * @author 张瑀楠 zyndev@gmail.com
21 * @version 1.0
22 * @date 2018 /2/23 12:26
23 */
24public class FastSqlRepositoryRegistrar implements ImportBeanDefinitionRegistrar, BeanFactoryAware {
25
26    private ConfigurableListableBeanFactory beanFactory;
27
28    @Override
29    public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {
30        System.out.println("FastSqlRepositoryRegistrar registerBeanDefinitions ");
31        String basePackage = "com.sparrow";
32        ClassScanner classScanner = new ClassScanner();
33        Set<Class<?>> classSet = null;
34        try {
35            classSet = classScanner.getPackageAllClasses(basePackage, true);
36        } catch (IOException | ClassNotFoundException e) {
37            e.printStackTrace();
38        }
39        for (Class clazz : classSet) {
40            if (clazz.getAnnotation(Repository.class) != null) {
41                beanFactory.registerSingleton(StringUtil.firstCharToLowerCase(clazz.getSimpleName()), FacadeProxy.newMapperProxy(clazz));
42            }
43        }
44    }
45
46    @Override
47    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
48        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
49    }
50}

10. invoke方法处理

在前面生成动态的代理的时候,可以看到,所有的invoke逻辑由StatementParser.java处理,这也是一个重量级的方法

invoke执行流程说明:

invoke(Object proxy, Method method, Object[] args)

  1. 得到方法的返回类型
  2. 得到方法的@Query注解,取得需要执行的sql语句,无法取到sql则抛异常
  3. 获得方法的参数,并将参数顺序对应为 ?1->arg0 ?2->arg1 …
  4. 获得方法的参数和参数上@Param注解,并将参数与对应的Param的名称关联:param1->arg0 password->arg1
  5. 判断sql是select还是其他,使用正则 (?i)select([\s\S]*?)
  6. 重写sql
  7. 如果不是 select 语句,判断是否是 @ReturnGeneratedKey 注解
  8. 如果无 @ReturnGeneratedKey 则直接执行语句并返回对应的结果
  9. 如有有 @ReturnGeneratedKey 并且是 insert 语句则返回生成的主键
  10. 如果是 select 语句,则执行select 语句,并根据方法的返回类型封装结果集

关于重写sql

代码语言:javascript
复制
 1@Query("insert into tb_user(id, 
 2account_name, 
 3password, 
 4uid, 
 5nick_name, 
 6register_time, 
 7update_time) values(
 8    :id, 
 9    :user.accountName, 
10    :user.password, 
11    :user.uid, 
12    :user.nickName, 
13    :user.registerTime, 
14    :user.updateTime )")
15    public int saveUser(@Param("id") Integer id, @Param("user") User user); 

首先获取sql

代码语言:javascript
复制
1insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
2values(:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime )

可以看出这并不是标准的 sql 也不是 jdbc 可以识别的sql,这里我们使用正则\?\d+(.[A-Za-z]+)?|:[A-Za-z0-9]+(.[A-Za-z]+)?

来提取出 ?1 ?2 :1 :2 :id :user.accountName 的特殊标志,并将其替换为 ?

替换过程中没替换一个 ?则记录对应的 ?所代表的数值

替换后的SQL为

代码语言:javascript
复制
1insert into tb_user(id, account_name, password, uid, nick_name, register_time, update_time)
2values(?, ?, ?, ?, ?, ?, ? )

这样的sql就可以被 jdbc 处理了 同时参数允许为:

代码语言:javascript
复制
1:id, :user.accountName, :user.password, :user.uid, :user.nickName, :user.registerTime, :user.updateTime

这里的 id 可以从参数中 id 直接获取, :user.accountName 则需要从参数 :user 即 user 中通过反射获取,这样 SQL 的重写就完成了

返回结果集封装可以通过反射,可以直接看下面代码

StatementParser.java

代码语言:javascript
复制
  1package com.zyndev.tool.fastsql.core;
  2
  3import com.sun.org.apache.bcel.internal.generic.IF_ACMPEQ;
  4import com.zyndev.tool.fastsql.annotation.Param;
  5import com.zyndev.tool.fastsql.annotation.Query;
  6import com.zyndev.tool.fastsql.annotation.ReturnGeneratedKey;
  7import com.zyndev.tool.fastsql.convert.BeanConvert;
  8import com.zyndev.tool.fastsql.convert.ListConvert;
  9import com.zyndev.tool.fastsql.convert.SetConvert;
 10import com.zyndev.tool.fastsql.util.BeanReflectionUtil;
 11import com.zyndev.tool.fastsql.util.StringUtil;
 12import org.apache.commons.logging.Log;
 13import org.apache.commons.logging.LogFactory;
 14import org.springframework.jdbc.core.*;
 15import org.springframework.jdbc.support.GeneratedKeyHolder;
 16import org.springframework.jdbc.support.KeyHolder;
 17import org.springframework.jdbc.support.rowset.SqlRowSet;
 18import sun.reflect.generics.reflectiveObjects.NotImplementedException;
 19
 20import java.lang.reflect.Method;
 21import java.lang.reflect.Parameter;
 22import java.lang.reflect.ParameterizedType;
 23import java.lang.reflect.Type;
 24import java.sql.Connection;
 25import java.sql.PreparedStatement;
 26import java.sql.SQLException;
 27import java.sql.Statement;
 28import java.util.HashMap;
 29import java.util.List;
 30import java.util.Map;
 31
 32/**
 33 * sql 语句解析
 34 * <p>
 35 * 暂时只能处理 select count(*) from tb_user 类似语句
 36 *
 37 * @author 张瑀楠 zyndev@gmail.com
 38 * @version 0.0.1
 39 * @since 2017 /12/23 下午12:11
 40 */
 41class StatementParser {
 42
 43    private final static Log logger = LogFactory.getLog(StatementParser.class);
 44
 45    private static PreparedStatementCreator getPreparedStatementCreator(final String sql, final Object[] args, final boolean returnKeys) {
 46        PreparedStatementCreator creator = new PreparedStatementCreator() {
 47
 48            @Override
 49            public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
 50                PreparedStatement ps = con.prepareStatement(sql);
 51                if (returnKeys) {
 52                    ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
 53                } else {
 54                    ps = con.prepareStatement(sql);
 55                }
 56
 57                if (args != null) {
 58                    for (int i = 0; i < args.length; i++) {
 59                        Object arg = args[i];
 60                        if (arg instanceof SqlParameterValue) {
 61                            SqlParameterValue paramValue = (SqlParameterValue) arg;
 62                            StatementCreatorUtils.setParameterValue(ps, i + 1, paramValue,
 63                                    paramValue.getValue());
 64                        } else {
 65                            StatementCreatorUtils.setParameterValue(ps, i + 1,
 66                                    SqlTypeValue.TYPE_UNKNOWN, arg);
 67                        }
 68                    }
 69                }
 70                return ps;
 71            }
 72        };
 73        return creator;
 74    }
 75
 76    /**
 77     * 此处对 Repository 中方法进行解析,解析成对应的sql 和 参数
 78     * <p>
 79     * sql 来自于 @Query 注解的 value
 80     * 参数 来自方法的参数
 81     * <p>
 82     * 注意根据返回值的不同封装结果集
 83     *
 84     * @param proxy  执行对象
 85     * @param method 执行方法
 86     * @param args   参数
 87     * @return object
 88     */
 89    static Object invoke(Object proxy, Method method, Object[] args) throws Exception {
 90
 91        JdbcTemplate jdbcTemplate = DataSourceHolder.getInstance().getJdbcTemplate();
 92
 93        boolean logDebug = logger.isDebugEnabled();
 94
 95        String methodReturnType = method.getReturnType().getName();
 96        Query query = method.getAnnotation(Query.class);
 97
 98        if (null == query || StringUtil.isBlank(query.value())) {
 99            logger.error(method.toGenericString() + " 无 query 注解或 SQL 为空");
100            throw new IllegalStateException(method.toGenericString() + " 无 query 注解或 SQL 为空");
101        }
102
103        String originSql = query.value().trim();
104
105        System.out.println("sql:" + query.value());
106        Map<String, Object> namedParamMap = new HashMap<>();
107        Parameter[] parameters = method.getParameters();
108        if (args != null && args.length > 0) {
109            for (int i = 0; i < args.length; ++i) {
110                Param param = parameters[i].getAnnotation(Param.class);
111                if (null != param) {
112                    namedParamMap.put(param.value(), args[i]);
113                }
114                namedParamMap.put("?" + (i + 1), args[i]);
115            }
116        }
117
118        if (logDebug) {
119            logger.debug("执行 sql: " + originSql);
120        }
121
122        // 判断 sql 类型, 判断是否为 select 开头语句
123        boolean isQuery = originSql.trim().matches("(?i)select([\\s\\S]*?)");
124        Object[] params = null;
125        // rewrite sql
126        if (null != args && args.length > 0) {
127            List<String> results = StringUtil.matches(originSql, "\\?\\d+(\\.[A-Za-z]+)?|:[A-Za-z0-9]+(\\.[A-Za-z]+)?");
128            if (results.isEmpty()) {
129                params = args;
130            } else {
131                params = new Object[results.size()];
132                for (int i = 0; i < results.size(); ++i) {
133                    if (results.get(i).charAt(0) == ':') {
134                        originSql = originSql.replaceFirst(results.get(i), "?");
135                        // 判断是否是 param.param 的格式
136                        if (!results.get(i).contains(".")) {
137                            params[i] = namedParamMap.get(results.get(i).substring(1));
138                        } else {
139                            String[] paramArgs = results.get(i).split("\\.");
140                            Object param = namedParamMap.get(paramArgs[0].substring(1));
141                            params[i] = BeanReflectionUtil.getFieldValue(param, paramArgs[1]);
142                        }
143                        continue;
144                    }
145                    int paramIndex = Integer.parseInt(results.get(i).substring(1));
146                    originSql = originSql.replaceFirst("\\?" + paramIndex, "?");
147                    params[i] = namedParamMap.get(results.get(i));
148                }
149            }
150        }
151
152
153        System.out.println("execute sql:" + originSql);
154        System.out.print("params : ");
155        if (null != params) {
156            for (Object o : params) {
157                System.out.print(o + ",\t");
158            }
159        }
160        System.out.println("\n");
161
162
163        /**
164         * 如果返回值是基本类型或者其包装类
165         */
166        System.out.println(methodReturnType);
167        if (isQuery) {
168            // 查询方法
169            if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
170                return jdbcTemplate.queryForObject(originSql, params, Integer.class);
171            } else if ("java.lang.String".equals(methodReturnType)) {
172                return jdbcTemplate.queryForObject(originSql, params, String.class);
173            } else if ("java.util.List".equals(methodReturnType) || "java.util.Set".equals(methodReturnType)) {
174                String typeName = null;
175                Type returnType = method.getGenericReturnType();
176                if (returnType instanceof ParameterizedType) {
177                    Type[] types = ((ParameterizedType) returnType).getActualTypeArguments();
178                    if (null == types || types.length > 1) {
179                        throw new IllegalArgumentException("当返回值为 list 时,必须标明具体类型,且只有一个");
180                    }
181                    typeName = types[0].getTypeName();
182                }
183                Object obj = BeanReflectionUtil.newInstance(typeName);
184                SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
185                if ("java.util.List".equals(methodReturnType)) {
186                    return ListConvert.convert(rowSet, obj);
187                }
188                return SetConvert.convert(rowSet, obj);
189            } else if ("java.util.Map".equals(methodReturnType)) {
190                throw new NotImplementedException();
191            } else {
192                SqlRowSet rowSet = jdbcTemplate.queryForRowSet(originSql, params);
193                Object obj = BeanReflectionUtil.newInstance(methodReturnType);
194                return BeanConvert.convert(rowSet, obj);
195            }
196        } else {
197            // 非查询方法
198            // 判断是否是insert 语句
199            ReturnGeneratedKey returnGeneratedKeyAnnotation = method.getAnnotation(ReturnGeneratedKey.class);
200            if (returnGeneratedKeyAnnotation == null) {
201                int retVal = jdbcTemplate.update(originSql, params);
202                if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
203                    return retVal;
204                } else if ("java.lang.Boolean".equals(methodReturnType)) {
205                    return retVal > 0;
206                }
207            } else {
208                // 判断是否是 insert 语句
209                boolean isInsertSql = originSql.trim().matches("(?i)insert([\\s\\S]*?)");
210                if (isInsertSql) {
211                    KeyHolder keyHolder = new GeneratedKeyHolder();
212                    PreparedStatementCreator preparedStatementCreator = getPreparedStatementCreator(originSql, params, true);
213                    jdbcTemplate.update(preparedStatementCreator, keyHolder);
214                    if ("java.lang.Integer".equals(methodReturnType) || "int".equals(methodReturnType)) {
215                        return keyHolder.getKey().intValue();
216                    } else if ("java.lang.Long".equals(methodReturnType) || "long".equals(methodReturnType)) {
217                        return keyHolder.getKey().longValue();
218                    }
219                    logger.error(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
220                    throw new IllegalArgumentException(method.toGenericString() + " 返回主键id应该为 int 或者 long 类型 ");
221                } else {
222                    logger.error(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey:sql语句为:" + originSql);
223                    throw new IllegalStateException(method.toGenericString() + " 非 insert 语句 无法返回 GeneratedKey:sql语句为:" + originSql);
224                }
225            }
226        }
227        return null;
228    }
229}

由此一个简单 ORM 就实现了,其实实现 ORM 并不难,难的是细心处理各种可能的 Bug

本文参与?腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-03-09,如有侵权请联系?cloudcommunity@tencent.com 删除

本文分享自 双鬼带单 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 实现一个 ORM 到底多简单
  • 原理
  • ORM 实现
    • 1. 通过注解来将 Java Bean 和数据库字段关联
      • 2. 反射工具类
        • 3. 简单的 model 示例
          • 4. 注解解析
            • 5. 数据库操作
              • 6. 结合反射实现查询操作
              • 使用动态代理实现 @Query @Select 类似功能
                • 1. 动态代理
                  • 2. 注解
                    • 3. 表设计
                      • 4. model
                        • 5. repository
                          • 7. 大体流程
                            • 8. 代理使用
                              • 9. 将生成代理放入 Spring IOC 容器中
                                • 10. invoke方法处理
                                相关产品与服务
                                数据库
                                云数据库为企业提供了完善的关系型数据库、非关系型数据库、分析型数据库和数据库生态工具。您可以通过产品选择和组合搭建,轻松实现高可靠、高可用性、高性能等数据库需求。云数据库服务也可大幅减少您的运维工作量,更专注于业务发展,让企业一站式享受数据上云及分布式架构的技术红利!
                                领券
                                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
                                http://www.vxiaotou.com