Go 在orm中使用反射

2024-09-08 13:38
文章标签 go 使用 反射 orm

本文主要是介绍Go 在orm中使用反射,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

作为静态语言,golang 稍显笨拙,还好 go 的标准包reflect(反射)包弥补了这点不足,它提供了一系列强大的 API,能够根据执行过程中对象的类型来改变程序控制流。本文将通过设计并实现一个简易的 mysql orm 来学习它,要求读者了解mysql基本知识,并且跟我一样至少已经接触 golang 两到三个月。

orm 这个概念相信同学们都非常熟悉,尤其是写过rails的同学,对active_record的强大肯定深有体会(得益于的method_missing和define_method方法,少写了海量代码),所以对 orm 我就不过多介绍了。本文要实现的 orm 只提供基本的CRUD(增删改查)和transaction(事务)功能,核心代码控制在 300 行左右。 如果想手把手照着写,需要先做一些准备工作。

源码github:https://github.com/zhenyz/simple_orm/blob/master/orm.go

准备工作

在本地 mysql 里create database orm_db,然后再create一张user表,结构如下:

CREATE TABLE `user` (`id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '自增主键',`age` smallint(10) unsigned NOT NULL DEFAULT 0 COMMENT '年龄',`first_name` varchar(45) NOT NULL DEFAULT '' COMMENT '姓',`last_name` varchar(45) NOT NULL DEFAULT '' COMMENT '名',`email` varchar(45) NOT NULL DEFAULT '' COMMENT '邮箱地址',`created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',`updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',PRIMARY KEY (`id`),KEY `idx_email` (`email`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='用户表';

同时,golang 代码里定义一个与之对应的struct:

type User struct {ID        int64     `json:"id"`         // 自增主键Age       int64     `json:"age"`        // 年龄FirstName string    `json:"first_name"` // 姓LastName  string    `json:"last_name"`  // 名Email     string    `json:"email"`      // 邮箱地址CreatedAt time.Time `json:"created_at"` // 创建时间UpdatedAt time.Time `json:"updated_at"` // 更新时间
}

与 mysql 交互需要用到一个 go 标准包和一个驱动,代码import如下:

package ormimport ("database/sql"//register driver_ "github.com/go-sql-driver/mysql"
)

首先按照database维度建立连接,写一个可以返回 mysql 连接的函数:

//Connect db by dsn e.g. "user:password@tcp(127.0.0.1:3306)/dbname"
func Connect(dsn string) (*sql.DB, error) {conn, err := sql.Open("mysql", dsn)if err != nil {return nil, err}//设置连接池conn.SetMaxOpenConns(100)conn.SetMaxIdleConns(10)conn.SetConnMaxLifetime(10 * time.Minute)return conn, conn.Ping()
}

设计一个struct用于实现 orm(go 不是面向对象的语言,没有class):

//Query will build a sql
type Query struct {db      *sql.DBtable   string
}

最后将通过Query拼接出 sql 语句与 mysql 交互,所以写一个绑定函数:

//Table bind db and table
func Table(db *sql.DB, tableName string) func() *Query {return func() *Query {return &Query{db:    db,table: tableName,}}
}

返回值是一个闭包函数,这样使用时直接调用这个闭包函数就可以获取一个绑定好的 databasetable 的Query,比如现在有数据库orm_db和user表:

//全局变量ormDB和users
ormDB, _ := Connect("user:password@tcp(127.0.0.1:3306)/orm_db")
users := Table(ormDB, "user")
//调用
users().Insert(...)

准备工作到此完成,下面进入正题。

Insert 方法

首先分析一下标准insert语句:

insert into user (first_name, last_name) values ('Tom', 'Cat'), ('Tom', 'Cruise')

把 sql 语句中变化的部分抽象出来,其实就是key(字段)和value(值),那么 orm 里的Insert方法原型就有了,如下,参数是 struct 或者 map,因为它们都能提供键值对:

//Insert in can be *User, []*User, map[string]interface{}
func (q *Query) Insert(in interface{}) (int64, error) {var keys, values []stringv := reflect.ValueOf(in)//剥离指针for v.Kind() == reflect.Ptr {v = v.Elem()}switch v.Kind() {case reflect.Struct:keys, values = sKV(v)case reflect.Map:keys, values = mKV(v)case reflect.Slice:for i := 0; i < v.Len(); i++ {//Kind是切片时,可以用Index()方法遍历sv := v.Index(i)for sv.Kind() == reflect.Ptr || sv.Kind() == reflect.Interface {sv = sv.Elem()}//切片元素不是struct或者指针,报错if sv.Kind() != reflect.Struct {return 0, errors.New("method Insert error: in slice is not structs")}//keys只保存一次就行,因为后面的都一样了if len(keys) == 0 {keys, values = sKV(sv)continue}_, val := sKV(sv)values = append(values, val...)}default:return 0, errors.New("method Insert error: type error")}//todo//...
}

参数in可以是一个User(前文定义好的结构体)实例的指针(或者指针集合),也可以是一个 map,这两个结构都可以提供键值对,我们通过反射来分析它的类型,然后根据类型执行相应的逻辑。 reflect 包里的有两个重要结构Type和Value,Type 是一个接口,定义了所有类型相关的 api,reflect 里的*rtype实现了这个接口,通过 reflect.TypeOf 函数可以获取任何传入值的*rtype。Value 是一个 struct,通过 reflect.ValueOf 函数获取,它在*rtype的基础上又封装了传入值的 unsafe.Pointer 类型的地址以及这个值的元数据。 在 TypeValue 之上还有一个Kind,它代表传入值的原始类型,比如:

type myInt int
var i myInt
t := reflect.TypeOf(i)
k := t.Kind()

tmyInt,而 kint,Type 和 Kind 是不同的,这一点要注意区分。 如果 Type 的 Kind 是指针、接口、切片、map 等复合类型,可以调用 Elem()方法获取基类型。 如果 Value 的 Kind 是指针、接口,可以调用 Elem()方法获取实际值。 Value 上还定义了一个Interface()方法,它是 ValueOf()方法的反操作。 有了上面这些反射方法,我们可以封装一个sKV()函数,它专门处理 struct 类型的值,获取 key(取 json tag)和 value:

func sKV(v reflect.Value) ([]string, []string) {var keys, values []stringt := v.Type()for n := 0; n < t.NumField(); n++ {tf := t.Field(n)vf := v.Field(n)//忽略非导出字段if tf.Anonymous {continue}//忽略无效、零值字段if !vf.IsValid() || reflect.DeepEqual(vf.Interface(), reflect.Zero(vf.Type()).Interface()) {continue}for vf.Type().Kind() == reflect.Ptr {vf = vf.Elem()}//有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value//如果字段值是time类型之外的struct,递归获取keys和valuesif vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {cKeys, cValues := sKV(vf)keys = append(keys, cKeys...)values = append(values, cValues...)continue}//根据字段的json tag获取key,忽略无tag字段key := strings.Split(tf.Tag.Get("json"), ",")[0]if key == "" {continue}value := format(vf)if value != "" {keys = append(keys, key)values = append(values, value)}}return keys, values
}

sKV()函数里需要格式化字符串,那么定义一个format()函数。 time.Time类型怎么转化成各种数据库的时间类型我有点拿不准,所以需要对比时间类型的值时,一律用 unix 时间戳,感觉比较省事不会出错:

func format(v reflect.Value) string {//断言出time类型直接转unix时间戳if t, ok := v.Interface().(time.Time); ok {return fmt.Sprintf("FROM_UNIXTIME(%d)", t.Unix())}switch v.Kind() {case reflect.String:return fmt.Sprintf(`'%s'`, v.Interface())case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:return fmt.Sprintf(`%d`, v.Interface())case reflect.Float32, reflect.Float64:return fmt.Sprintf(`%f`, v.Interface())//如果是切片类型,遍历元素,递归格式化成"(, , , )"形式case reflect.Slice:var values []stringfor i := 0; i < v.Len(); i++ {values = append(values, format(v.Index(i)))}return fmt.Sprintf(`(%s)`, strings.Join(values, ","))//接口类型剥一层递归case reflect.Interface:return format(v.Elem())}return ""
}

map 类型处理起来和 struct 不同,所以我们再定义一个mKV()函数,目的和 sKV()一样,都是获取键值对:

func mKV(v reflect.Value) ([]string, []string) {var keys, values []string//获取map的key组成的切片mapKeys := v.MapKeys()for _, key := range mapKeys {value := format(v.MapIndex(key))if value != "" {values = append(values, value)keys = append(keys, key.Interface().(string))}}return keys, values
}

利用 sKV()mKV()函数取到键值对后,就得到了 insert 语句中的变化部分,补全 Insert()方法的todo部分:

//Insert in can be User, *User, []User, []*User, map[string]interface{}
func (q *Query) Insert(in interface{}) (int64, error) {//already donekl := len(keys)vl := len(values)if kl == 0 || vl == 0 {return 0, errors.New("method Insert error: no data")}var insertValue string//插入多条记录时需要用","拼接一下valuesif kl < vl {var tmpValues []stringfor kl <= vl {if kl%(len(keys)) == 0 {tmpValues = append(tmpValues, fmt.Sprintf("(%s)", strings.Join(values[kl-len(keys):kl], ",")))}kl++}insertValue = strings.Join(tmpValues, ",")} else {insertValue = fmt.Sprintf("(%s)", strings.Join(values, ","))}query := fmt.Sprintf(`insert into %s (%s) values %s`, q.table, strings.Join(keys, ","), insertValue)log.Printf("insert sql: %s", query)st, err := q.DB.Prepare(query)if err != nil {return 0, err}result, err := st.Exec()if err != nil {return 0, err}return result.LastInsertId()
}

原理很简单,利用反射分析参数,取键值对,然后拼接 sql 语句,再通过 mysql 驱动入库。 调用示例:

user1 := &User{Age:       30,FirstName: "Tom",LastName:  "Cat",
}
user2 := User{Age:       30,FirstName: "Tom",LastName:  "Curise",
}
user3 := User{Age:       30,FirstName: "Tom",LastName:  "Hanks",
}
user4 := map[string]interface{}{"age":        30,"first_name": "Tom","last_name":  "Zzy",
}
users().Insert([]interface{}{user1, user2})
users().Insert(user3)
users().Insert(user4)

增删改查的增部分到此完成,因为查询语句非常复杂多变,所以有了数据后,先进行查。

Select 方法

先分析一下标准select语句

select id, age from user where first_name = 'Tom' and last_name = 'Cat'

可见 sql 语句的变量部分是select后面的字段和where后面的键值对,所以我们需要一个Where()来方法构造查询条件,并且需要一个Select()方法最后执行查询,最终形成一个链式调用效果:

var user []User
users().Where(?).WhereNot(?).Limit(100).Offset(100).Order("id desc").Only("id", "age").Select(&user)

所以需要改造 Query 如下,增加属性用于暂存链式调用中添加的值:

//Query will build a sql
type Query struct {db     *sql.DBtable  stringwheres []stringonly   []stringlimit  stringoffset stringorder  stringerrs   []string
}

Query 添加 Where()方法,支持 structmap 参数,同时支持传如同"age > 10"形式的字符串:

//Where args can be string, User, *User, map[string]interface{}
func (q *Query) Where(wheres ...interface{}) *Query {for _, w := range wheres {v := reflect.ValueOf(w)for v.Kind() == reflect.Ptr {v = v.Elem()}switch v.Kind() {case reflect.String:q.wheres = append(q.wheres, w.(string))case reflect.Struct://todocase reflect.Map://tododefault:q.errs = append(q.errs, "method Where error: type error")}}return q
}

但是考虑到后面还会实现一个WhereNot()方法,所以把公共逻辑抽到一个where()函数里,并且直接复用之前的 sKV()mKv()函数获取键值对:

func where(eq bool, w interface{}) (string, error) {var keys, values []stringv := reflect.ValueOf(w)for v.Kind() == reflect.Ptr {v = v.Elem()}switch v.Kind() {case reflect.String:return w.(string), nilcase reflect.Struct:keys, values = sKV(v)case reflect.Map:keys, values = mKV(v)default:return "", errors.New("method Where error: type error")}if len(keys) != len(values) {return "", errors.New("method Where error: len(keys) not equal len(values))")}var wheres []string//之前的format()函数里,已经将切片类型值处理成"( , , ,)“形式for idx, key := range keys {if eq {if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {wheres = append(wheres, fmt.Sprintf("%s in %s", key, values[idx]))continue}wheres = append(wheres, fmt.Sprintf("%s = %s", key, values[idx]))continue}if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {wheres = append(wheres, fmt.Sprintf("%s not in %s", key, values[idx]))continue}wheres = append(wheres, fmt.Sprintf("%s != %s", key, values[idx]))}return strings.Join(wheres, " and "), nil
}

Where()方法最终变成:

//Where args can be string, User, *User, map[string]interface{}
func (q *Query) Where(wheres ...interface{}) *Query {for _, w := range wheres {str, err := where(true, w)q.wheres = append(q.wheres, str)if err != nil {//因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理q.errs = append(q.errs, err.Error())}}return q
}

WhereNot()把调用 where()的第一个参数改成 false 就行了,不贴代码了。 Limit()Offset()Order()Only()这几个方法也很简单:

//Limit .
func (q *Query) Limit(limit uint) *Query {q.limit = fmt.Sprintf("limit %d", limit)return q
}//Offset .
func (q *Query) Offset(offset uint) *Query {q.offset = fmt.Sprintf("offset %d", offset)return q
}//Order .
func (q *Query) Order(ord string) *Query {q.order = fmt.Sprintf("order by %s", ord)return q
}//Only 指定需要查询的字段
func (q *Query) Only(columns ...string) *Query {q.only = append(q.only, columns...)return q
}

有了上面这些条件之后,我们可以写一个toSQL()方法,把 Query 的属性组装成一条 sql 语句:

func (q *Query) toSQL() string {var where stringif len(q.wheres) > 0 {where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))}sqlStr := fmt.Sprintf(`select %s from %s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)log.Printf("select sql: %s", sqlStr)return sqlStr
}

有了 sql 语句我们就可以查询数据了,但是想查一个表的全部字段时,为了方便,只需要传入对应的struct,比如user表对应的User,我们就直接分析这个 struct,取它的 tag 作为查询字段,而不需要再调用 Only()方法指定字段。 另外,因为 golang 中的参数传递全都是值传递,要修改传入值,必须传值的指针,这里要注意一点:

var user User
users.Select(&user)
var userPtr *User
users.Select(user)

这两种声明方式是不同的,后者只声明了一个指针类型,是错误的。 综上,我们首先为 Select()方法做一下的参数检查,确保传入值是一个正确的指针,并确保 only 属性有值:

//Select dest must be a ptr, e.g. *user, *[]user, *[]*user, *map, *[]map, *int, *[]int
func (q *Query) Select(dest interface{}) error {if len(q.errs) != 0 {return errors.New(strings.Join(q.errs, "
"))}t := reflect.TypeOf(dest)v := reflect.ValueOf(dest)typeErr := errors.New("method Select error: type error")if t.Kind() != reflect.Ptr {return typeErr}//如果是用 var userPtr *User 方式声明的变量,则不可取址if !v.Elem().CanAddr() {return typeErr}t = t.Elem()v = v.Elem()//如果only此时仍然为空,说明Only()方法未被调用,我们从struct上取tag填充if len(q.only) == 0 {switch t.Kind() {case reflect.Struct:if t.Name() != "Time" {q.only = sK(v)}case reflect.Slice://获取切片的基本类型给一个局部变量t := t.Elem()if t.Kind() == reflect.Ptr {t = t.Elem()}if t.Kind() == reflect.Struct {if t.Name() != "Time" {q.only = sK(reflect.Zero(t))}}}}if len(q.only) == 0 {return errors.New("method Select error: type error, no columns to select")}if t.Kind() != reflect.Slice {q.limit = "limit 1"}//todo
}

这里只取 struct 的 tag,不取 value,我们定义一个新的 sK()函数:

func sK(v reflect.Value) []string {var keys []stringt := v.Type()for n := 0; n < t.NumField(); n++ {tf := t.Field(n)vf := v.Field(n)//忽略非导出字段if tf.Anonymous {continue}for vf.Type().Kind() == reflect.Ptr {vf = vf.Elem()}//如果字段值是time类型之外的struct,递归获取keysif vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {keys = append(keys, sK(vf)...)continue}//根据字段的json tag获取key,忽略无tag字段key := strings.Split(tf.Tag.Get("json"), ",")[0]if key == "" {continue}keys = append(keys, key)}return keys
}

现在 sql 语句已经完备了,可以执行最后的取值步骤了。 我们根据传入 Select()的指针的基类型生成实际数据,对其取址后交给 sql 包的Scan()方法填充,然后Set()回去,所以这里需要一个address()函数用于取址:

func address(dest reflect.Value, columns []string) []interface{} {dest = dest.Elem()t := dest.Type()addrs := make([]interface{}, 0)switch t.Kind() {case reflect.Struct:for n := 0; n < t.NumField(); n++ {tf := t.Field(n)vf := dest.Field(n)if tf.Anonymous {continue}for vf.Type().Kind() == reflect.Ptr {vf = vf.Elem()}//如果字段值是time类型之外的struct,递归取址if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {nVf := reflect.New(vf.Type())vf.Set(nVf.Elem())addrs = append(addrs, address(nVf, columns)...)continue}column := strings.Split(tf.Tag.Get("json"), ",")[0]if column == "" {continue}//只取选定的字段的地址for _, col := range columns {if col == column {addrs = append(addrs, vf.Addr().Interface())break}}}default:addrs = append(addrs, dest.Addr().Interface())}return addrs
}

Value.Addr()方法可用于取址,前提是Value.CanAddr()返回 true。 relfect.New()可以根据Type来new出一个Value,这个 Value 是一个指针,它的基值是可以取址的,把它的基值Set()到目标值上,就达到了根据 Type 从无到有生成对应值的目的。 因为 map 不能用 new()函数生成,所以需要写一个用于生成 map 的函数setMap()

//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言
func (q *Query) setMap(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {if t.Elem().Kind() != reflect.Interface {return reflect.ValueOf(nil), errors.New("method setMap error: type error, must be map[string]interface{}")}m := reflect.MakeMap(t)addrs := make([]interface{}, len(q.only))for idx := range q.only {addrs[idx] = new(interface{})}if err := rows.Scan(addrs...); err != nil {return reflect.ValueOf(nil), err}for idx, column := range q.only {//从指针剥出interface{},再剥出实际值m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem().Elem())}return m, nil
}

reflect.MakeMap()make()作用差不多,它接受一个Kind是reflect.Map的Type作为参数,生成一个对应类型的 map。 对于其它适用于new的类型,写一个通用的函数setElem()处理:

//适用于基类型和struct
func (q *Query) setElem(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {addrsErr := errors.New("method setElem error: columns not match addresses")dest := reflect.New(t)addrs := address(dest, q.only)if len(q.only) != len(addrs) {return reflect.ValueOf(nil), addrsErr}if err := rows.Scan(addrs...); err != nil {return reflect.ValueOf(nil), err}return dest, nil
}

这些函数完成后,就可以着手完善 Select()里的 todo 部分了:

//already done
rows, err := q.DB.Query(q.toSQL())if err != nil {return err}switch t.Kind() {case reflect.Slice:dt := t.Elem()for dt.Kind() == reflect.Ptr {dt = dt.Elem()}sl := reflect.MakeSlice(t, 0, 0)for rows.Next() {var destination reflect.Valueif dt.Kind() == reflect.Map {destination, err = q.setMap(rows, dt)} else {destination, err = q.setElem(rows, dt)}if err != nil {return err}//区分切片元素是否指针switch t.Elem().Kind() {case reflect.Ptr, reflect.Map:sl = reflect.Append(sl, destination)default:sl = reflect.Append(sl, destination.Elem())}}v.Set(sl)return nilcase reflect.Map:for rows.Next() {m, err := q.setMap(rows, t)if err != nil {return err}v.Set(m)}return nildefault:for rows.Next() {destination, err := q.setElem(rows, t)if err != nil {return err}v.Set(destination.Elem())}}return nil

至此,Select()方法就大功告成了,部分调用方式示例:

var user User
users()
.Where("first_name = 'Tom'", map[string]interface{}{"id": []int{1, 2, 3, 4},
})
.WhereNot(&User{LastName: "Cat"})
.Only("last_name")
.Select(&user)var userMore []User
users().Where("first_name = 'Tom'").Order("id desc").Select(&userMore)
var userMoreP []*User
users().Where("first_name = 'Tom'").Select(&userMoreP)
var lastName string
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&lastName)
var lastNames []string
users().Where(map[string]interface{}{"first_name": "Tom",
}).Only("last_name").Select(&lastNames)
var userM map[string]interface{}
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&userM)
var userMS []map[string]interface{}
users().Where("age > 10").Only("last_name", "age").Limit(100).Select(&userMS)

Update 方法

分析 update sql 语句:

update user set first_name = "z", last_name = "zy" where first_name = "Tom" and last_name = "Curise

比较简单,直接复用之前写的 sKV()mKV()函数:

//Update src can be *user, user, map[string]interface{}, string
func (q *Query) Update(src interface{}) (int64, error) {if len(q.errs) != 0 {return 0, errors.New(strings.Join(q.errs, "
"))}v := reflect.ValueOf(src)for v.Kind() == reflect.Ptr {v = v.Elem()}var toBeUpdated, where stringvar keys, values []stringswitch v.Kind() {case reflect.String:toBeUpdated = src.(string)case reflect.Struct:keys, values = sKV(v)case reflect.Map:keys, values = mKV(v)default:return 0, errors.New("method Update error: type error")}if toBeUpdated == "" {if len(keys) != len(values) {return 0, errors.New("method Update error: keys not match values")}var kvs []stringfor idx, key := range keys {kvs = append(kvs, fmt.Sprintf("%s = %s", key, values[idx]))}toBeUpdated = strings.Join(kvs, ",")}if len(q.wheres) > 0 {where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))}query := fmt.Sprintf("update %s set %s %s", q.table, toBeUpdated, where)st, err := q.DB.Prepare(query)if err != nil {return 0, err}result, err := st.Exec()if err != nil {return 0, err}return result.RowsAffected()
}

调用方式:

u1 := "age = 100"
u2 := map[string]interface{}{"age":        100,"first_name": "z","last_name":  "zy",
}
u3 := &User{Age:       100,FirstName: "z",LastName:  "zy",
}
_, _ = users().Where("age > 10").Update(u1)
_, _ = users().Where("age > 10").Update(u2)
_, _ = users().Where("age > 10").Update(u3)

Delete 方法

这个最简单,没啥好说的:

//Delete no args
func (q *Query) Delete() (int64, error) {if len(q.errs) != 0 {return 0, errors.New(strings.Join(q.errs, "
"))}var where stringif len(q.wheres) > 0 {where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))}st, err := q.DB.Prepare(fmt.Sprintf(`delete from %s %s`, q.table, where))if err != nil {return 0, err}result, err := st.Exec()if err != nil {return 0, err}return result.RowsAffected()
}

删除 id 为 1,2,3,4,并且 age 大于 10 的用户的调用方式:

w := map[string]interface{}{"id": []int{1, 2, 3, 4},
}
_, _ = users().Where(w, "age > 10").Delete()

最后,写一个简单的事务处理函数Transaction()

Transaction 函数

事务有三个关键动作begin,rollback,commit。 begin 后,要求所有操作要不全部成功,要不全部失败,所以我们要检查所有 error,一旦出现错误就 rollback,并且还要recover程序的 panic,发现 panic 时也要 rollback,直到最后确保无错,才能 commit。 调用*sql.DB.Begin()方法后,我们会得到一个事务具柄,事务内的 mysql 交互都要通过它来进行,它也实现了Query()Prepare()等方法。 所以我们定义一个接口:

//Dba *sql.DB or *sql.Tx
type Dba interface {Query(string, ...interface{}) (*sql.Rows, error)Prepare(string) (*sql.Stmt, error)
}

然后把Query结构体的DB属性的类型改成这个接口:

//Query will build a sql
type Query struct {DB     Dba...
}

同时, 改造Table()函数:

//Table bind db and table
func Table(db *sql.DB, tableName string) func(...Dba) *Query {return func(tx ...Dba) *Query {if len(tx) == 1 {return &Query{DB:    tx[0],table: tableName,}}return &Query{DB:    db,table: tableName,}}
}

这样我们就可以有选择性的和 mysql 进行普通交互或者事务交互。 然后把Transaction()函数写成这样:

//Transaction .
func Transaction(db *sql.DB, f func(Dba) error) (err error) {tx, err := db.Begin()if err != nil {return err}defer func() {p := recover()if err != nil {if rerr := tx.Rollback(); rerr != nil {panic(rerr)}return}if p != nil {if rerr := tx.Rollback(); rerr != nil {panic(rerr)}err = fmt.Errorf("function Transaction error: %v", p)return}if cerr := tx.Commit(); cerr != nil {panic(cerr)}}()err = f(tx)return err
}

第二个参数是一个接受事务具柄,返回 error 的函数,我们将需要事务的操作全部封装在这个函数里,就能抓到所有的 panic 和 error。 调用方式示例:

func doTx() error {ormDB, err := Connect("root@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local")if err != nil {panic(err)}users := Table(ormDB, "user")args := something()//利用闭包传递变量f := func(tx Dba) error {var id int//select语句无需在事务具柄上进行if err := users().Where(args).Select(&id); err != nil {return err}//增删改需要在事务上进行if _, err = users(tx).Insert(args); err != nil {return err}if _, err = users(tx).Update(args); err != nil {return err}if _, err = users(tx).Where(args).Delete(); err != nil {return err}return nil}//开始事务if err := Transaction(ormDB, f); err != nil {return err}return nil
}

到此,这个迷你 orm 的增删改查和事务功能全部都实现了,代码大概 600 行,比我预想的多了一倍。

后记

golang 的反射虽然强大(其实并不,没有 ruby 的元编程那么方便),但还是比较烦琐的,而且类型不对时动不动就 panic,使用的时候要尽量检查一下 Kind。

这篇关于Go 在orm中使用反射的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1148262

相关文章

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

使用SQL语言查询多个Excel表格的操作方法

《使用SQL语言查询多个Excel表格的操作方法》本文介绍了如何使用SQL语言查询多个Excel表格,通过将所有Excel表格放入一个.xlsx文件中,并使用pandas和pandasql库进行读取和... 目录如何用SQL语言查询多个Excel表格如何使用sql查询excel内容1. 简介2. 实现思路3

java脚本使用不同版本jdk的说明介绍

《java脚本使用不同版本jdk的说明介绍》本文介绍了在Java中执行JavaScript脚本的几种方式,包括使用ScriptEngine、Nashorn和GraalVM,ScriptEngine适用... 目录Java脚本使用不同版本jdk的说明1.使用ScriptEngine执行javascript2.

c# checked和unchecked关键字的使用

《c#checked和unchecked关键字的使用》C#中的checked关键字用于启用整数运算的溢出检查,可以捕获并抛出System.OverflowException异常,而unchecked... 目录在 C# 中,checked 关键字用于启用整数运算的溢出检查。默认情况下,C# 的整数运算不会自

在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码

《在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码》在MyBatis的XML映射文件中,trim元素用于动态添加SQL语句的一部分,处理前缀、后缀及多余的逗号或连接符,示... 在MyBATis的XML映射文件中,<trim>元素用于动态地添加SQL语句的一部分,例如SET或W

Mybatis官方生成器的使用方式

《Mybatis官方生成器的使用方式》本文详细介绍了MyBatisGenerator(MBG)的使用方法,通过实际代码示例展示了如何配置Maven插件来自动化生成MyBatis项目所需的实体类、Map... 目录1. MyBATis Generator 简介2. MyBatis Generator 的功能3

Go语言实现将中文转化为拼音功能

《Go语言实现将中文转化为拼音功能》这篇文章主要为大家详细介绍了Go语言中如何实现将中文转化为拼音功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 有这么一个需求:新用户入职 创建一系列账号比较麻烦,打算通过接口传入姓名进行初始化。想把姓名转化成拼音。因为有些账号即需要中文也需要英

Python中使用defaultdict和Counter的方法

《Python中使用defaultdict和Counter的方法》本文深入探讨了Python中的两个强大工具——defaultdict和Counter,并详细介绍了它们的工作原理、应用场景以及在实际编... 目录引言defaultdict的深入应用什么是defaultdictdefaultdict的工作原理

使用Python进行文件读写操作的基本方法

《使用Python进行文件读写操作的基本方法》今天的内容来介绍Python中进行文件读写操作的方法,这在学习Python时是必不可少的技术点,希望可以帮助到正在学习python的小伙伴,以下是Pyth... 目录一、文件读取:二、文件写入:三、文件追加:四、文件读写的二进制模式:五、使用 json 模块读写

Python使用qrcode库实现生成二维码的操作指南

《Python使用qrcode库实现生成二维码的操作指南》二维码是一种广泛使用的二维条码,因其高效的数据存储能力和易于扫描的特点,广泛应用于支付、身份验证、营销推广等领域,Pythonqrcode库是... 目录一、安装 python qrcode 库二、基本使用方法1. 生成简单二维码2. 生成带 Log