From babb5963b5f6c0663a1108503ef919de5fc1041d Mon Sep 17 00:00:00 2001 From: starainrt Date: Sat, 1 Jul 2023 18:19:58 +0800 Subject: [PATCH] init --- column.go | 318 +++++++++++++++ doc.go | 25 ++ expression.go | 733 +++++++++++++++++++++++++++++++++++ go.mod | 8 + go.sum | 33 ++ statement.go | 1022 +++++++++++++++++++++++++++++++++++++++++++++++++ table.go | 321 ++++++++++++++++ test_utils.go | 26 ++ types.go | 79 ++++ 9 files changed, 2565 insertions(+) create mode 100644 column.go create mode 100644 doc.go create mode 100644 expression.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 statement.go create mode 100644 table.go create mode 100644 test_utils.go create mode 100644 types.go diff --git a/column.go b/column.go new file mode 100644 index 0000000..ed4fee6 --- /dev/null +++ b/column.go @@ -0,0 +1,318 @@ +// Modeling of columns + +package sqlbuilder + +import ( + "bytes" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +// XXX: Maybe add UIntColumn + +// Representation of a table for query generation +type Column interface { + isProjectionInterface + + Name() string + // Serialization for use in column lists + SerializeSqlForColumnList(out *bytes.Buffer) error + // Serialization for use in an expression (Clause) + SerializeSql(out *bytes.Buffer) error + + // Internal function for tracking table that a column belongs to + // for the purpose of serialization + setTableName(table string) error +} + +type NullableColumn bool + +const ( + Nullable NullableColumn = true + NotNullable NullableColumn = false +) + +// A column that can be refer to outside of the projection list +type NonAliasColumn interface { + Column + isOrderByClauseInterface + isExpressionInterface +} + +type Collation string + +const ( + UTF8CaseInsensitive Collation = "utf8_unicode_ci" + UTF8CaseSensitive Collation = "utf8_unicode" + UTF8Binary Collation = "utf8_bin" +) + +// Representation of MySQL charsets +type Charset string + +const ( + UTF8 Charset = "utf8" +) + +// The base type for real materialized columns. +type baseColumn struct { + isProjection + isExpression + name string + nullable NullableColumn + table string +} + +func (c *baseColumn) Name() string { + return c.name +} + +func (c *baseColumn) setTableName(table string) error { + c.table = table + return nil +} + +/*func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if c.table != "" { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.table) + _, _ = out.WriteString("`.") + } + _, _ = out.WriteString("`") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} +*/ +func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + // Momo modified. we don't need prefixing table name + /* + if c.table != "" { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.table) + _, _ = out.WriteString("`.") + } + */ + _, _ = out.WriteString("`") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { + return c.SerializeSqlForColumnList(out) +} + +type bytesColumn struct { + baseColumn + isExpression +} + +// Representation of VARBINARY/BLOB columns +// This function will panic if name is not valid +func BytesColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bytes column") + } + bc := &bytesColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type stringColumn struct { + baseColumn + isExpression + charset Charset + collation Collation +} + +// Representation of VARCHAR/TEXT columns +// This function will panic if name is not valid +func StrColumn( + name string, + charset Charset, + collation Collation, + nullable NullableColumn) NonAliasColumn { + + if !validIdentifierName(name) { + panic("Invalid column name in str column") + } + sc := &stringColumn{charset: charset, collation: collation} + sc.name = name + sc.nullable = nullable + return sc +} + +type dateTimeColumn struct { + baseColumn + isExpression +} + +// Representation of DateTime columns +// This function will panic if name is not valid +func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in datetime column") + } + dc := &dateTimeColumn{} + dc.name = name + dc.nullable = nullable + return dc +} + +type integerColumn struct { + baseColumn + isExpression +} + +// Representation of any integer column +// This function will panic if name is not valid +func IntColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &integerColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type doubleColumn struct { + baseColumn + isExpression +} + +// Representation of any double column +// This function will panic if name is not valid +func DoubleColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &doubleColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type booleanColumn struct { + baseColumn + isExpression + + // XXX: Maybe allow isBoolExpression (for now, not included because + // the deferred lookup equivalent can never be isBoolExpression) +} + +// Representation of TINYINT used as a bool +// This function will panic if name is not valid +func BoolColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bool column") + } + bc := &booleanColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type aliasColumn struct { + baseColumn + expression Expression +} + +func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if !validIdentifierName(c.name) { + return errors.Newf( + "Invalid alias name `%s`. Generated sql: %s", + c.name, + out.String()) + } + if c.expression == nil { + return errors.Newf( + "Cannot alias a nil expression. Generated sql: %s", + out.String()) + } + + _ = out.WriteByte('(') + if c.expression == nil { + return errors.Newf("nil alias clause. Generate sql: %s", out.String()) + } + if err := c.expression.SerializeSql(out); err != nil { + return err + } + _, _ = out.WriteString(") AS `") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) setTableName(table string) error { + return errors.Newf( + "Alias column '%s' should never have setTableName called on it", + c.name) +} + +// Representation of aliased clauses (expression AS name) +func Alias(name string, c Expression) Column { + ac := &aliasColumn{} + ac.name = name + ac.expression = c + return ac +} + +// This is a strict subset of the actual allowed identifiers +var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") + +// Returns true if the given string is suitable as an identifier. +func validIdentifierName(name string) bool { + //return validIdentifierRegexp.MatchString(name) + return true +} + +// Pseudo Column type returned by table.C(name) +type deferredLookupColumn struct { + isProjection + isExpression + table *Table + colName string + + cachedColumn NonAliasColumn +} + +func (c *deferredLookupColumn) Name() string { + return c.colName +} + +func (c *deferredLookupColumn) SerializeSqlForColumnList( + out *bytes.Buffer) error { + + return c.SerializeSql(out) +} + +func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { + if c.cachedColumn != nil { + return c.cachedColumn.SerializeSql(out) + } + + col, err := c.table.getColumn(c.colName) + if err != nil { + return err + } + + c.cachedColumn = col + return col.SerializeSql(out) +} + +func (c *deferredLookupColumn) setTableName(table string) error { + return errors.Newf( + "Lookup column '%s' should never have setTableName called on it", + c.colName) +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3f9170a --- /dev/null +++ b/doc.go @@ -0,0 +1,25 @@ +// A library for generating sql programmatically. +// +// SQL COMPATIBILITY NOTE: sqlbuilder is designed to generate valid MySQL sql +// statements. The generated statements may not work for other sql variants. +// For instances, the generated statements does not currently work for +// PostgreSQL since column identifiers are escaped with backquotes. +// Patches to support other sql flavors are welcome! (see +// https://godropbox/issues/33 for additional details). +// +// Known limitations for SELECT queries: +// - does not support subqueries (since mysql is bad at it) +// - does not currently support join table alias (and hence self join) +// - does not support NATURAL joins and join USING +// +// Known limitation for INSERT statements: +// - does not support "INSERT INTO SELECT" +// +// Known limitation for UPDATE statements: +// - does not support update without a WHERE clause (since it is dangerous) +// - does not support multi-table update +// +// Known limitation for DELETE statements: +// - does not support delete without a WHERE clause (since it is dangerous) +// - does not support multi-table delete +package sqlbuilder diff --git a/expression.go b/expression.go new file mode 100644 index 0000000..2a8d877 --- /dev/null +++ b/expression.go @@ -0,0 +1,733 @@ +// Query building functions for expression components +package sqlbuilder + +import ( + "bytes" + "reflect" + "strconv" + "strings" + "time" + + "b612.me/mysql/sqltypes" + "github.com/dropbox/godropbox/errors" +) + +type orderByClause struct { + isOrderByClause + expression Expression + ascent bool +} + +func (o *orderByClause) SerializeSql(out *bytes.Buffer) error { + if o.expression == nil { + return errors.Newf( + "nil order by clause. Generated sql: %s", + out.String()) + } + + if err := o.expression.SerializeSql(out); err != nil { + return err + } + + if o.ascent { + _, _ = out.WriteString(" ASC") + } else { + _, _ = out.WriteString(" DESC") + } + + return nil +} + +func Asc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: true} +} + +func Desc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: false} +} + +// Representation of an escaped literal +type literalExpression struct { + isExpression + value sqltypes.Value +} + +func (c literalExpression) SerializeSql(out *bytes.Buffer) error { + sqltypes.Value(c.value).EncodeSql(out) + return nil +} + +func serializeClauses( + clauses []Clause, + separator []byte, + out *bytes.Buffer) (err error) { + + if clauses == nil || len(clauses) == 0 { + return errors.Newf("Empty clauses. Generated sql: %s", out.String()) + } + + if clauses[0] == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = clauses[0].SerializeSql(out); err != nil { + return + } + + for _, c := range clauses[1:] { + _, _ = out.Write(separator) + + if c == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = c.SerializeSql(out); err != nil { + return + } + } + + return nil +} + +// Representation of n-ary conjunctions (AND/OR) +type conjunctExpression struct { + isExpression + isBoolExpression + expressions []BoolExpression + conjunction []byte +} + +func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(conj.expressions) == 0 { + return errors.Newf( + "Empty conjunction. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) + for i, expr := range conj.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, conj.conjunction, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +// Representation of n-ary arithmetic (+ - * /) +type arithmeticExpression struct { + isExpression + expressions []Expression + operator []byte +} + +func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(arith.expressions) == 0 { + return errors.Newf( + "Empty arithmetic expression. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(arith.expressions), len(arith.expressions)) + for i, expr := range arith.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, arith.operator, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +type tupleExpression struct { + isExpression + elements listClause +} + +func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error { + if len(tuple.elements.clauses) < 1 { + return errors.Newf("Tuples must include at least one element") + } + return tuple.elements.SerializeSql(out) +} + +func Tuple(exprs ...Expression) Expression { + clauses := make([]Clause, 0, len(exprs)) + for _, expr := range exprs { + clauses = append(clauses, expr) + } + return &tupleExpression{ + elements: listClause{ + clauses: clauses, + includeParentheses: true, + }, + } +} + +// Representation of a tuple enclosed, comma separated list of clauses +type listClause struct { + clauses []Clause + includeParentheses bool +} + +func (list *listClause) SerializeSql(out *bytes.Buffer) error { + if list.includeParentheses { + _ = out.WriteByte('(') + } + + if err := serializeClauses(list.clauses, []byte(","), out); err != nil { + return err + } + + if list.includeParentheses { + _ = out.WriteByte(')') + } + return nil +} + +// A not expression which negates a expression value +type negateExpression struct { + isExpression + isBoolExpression + + nested BoolExpression +} + +func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) { + _, _ = out.WriteString("NOT (") + + if c.nested == nil { + return errors.Newf("nil nested. Generated sql: %s", out.String()) + } + if err = c.nested.SerializeSql(out); err != nil { + return + } + + _ = out.WriteByte(')') + return nil +} + +// Returns a representation of "not expr" +func Not(expr BoolExpression) BoolExpression { + return &negateExpression{ + nested: expr, + } +} + +// Representation of binary operations (e.g. comparisons, arithmetic) +type binaryExpression struct { + isExpression + lhs, rhs Expression + operator []byte +} + +func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { + if c.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if err = c.lhs.SerializeSql(out); err != nil { + return + } + + _, _ = out.Write(c.operator) + + if c.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if err = c.rhs.SerializeSql(out); err != nil { + return + } + + return nil +} + +// A binary expression that evaluates to a boolean value. +type boolExpression struct { + isBoolExpression + binaryExpression +} + +func newBoolExpression(lhs, rhs Expression, operator []byte) *boolExpression { + // go does not allow {} syntax for initializing promoted fields ... + expr := new(boolExpression) + expr.lhs = lhs + expr.rhs = rhs + expr.operator = operator + return expr +} + +type funcExpression struct { + isExpression + funcName string + args *listClause +} + +func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) { + if !validIdentifierName(c.funcName) { + return errors.Newf( + "Invalid function name: %s. Generated sql: %s", + c.funcName, + out.String()) + } + _, _ = out.WriteString(c.funcName) + if c.args == nil { + _, _ = out.WriteString("()") + } else { + return c.args.SerializeSql(out) + } + return nil +} + +// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) +func SqlFunc(funcName string, expressions ...Expression) Expression { + f := &funcExpression{ + funcName: funcName, + } + if len(expressions) > 0 { + args := make([]Clause, len(expressions), len(expressions)) + for i, expr := range expressions { + args[i] = expr + } + + f.args = &listClause{ + clauses: args, + includeParentheses: true, + } + } + return f +} + +type intervalExpression struct { + isExpression + duration time.Duration + negative bool +} + +var intervalSep = ":" + +func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) { + hours := c.duration / time.Hour + minutes := (c.duration % time.Hour) / time.Minute + sec := (c.duration % time.Minute) / time.Second + msec := (c.duration % time.Second) / time.Microsecond + _, _ = out.WriteString("INTERVAL '") + if c.negative { + _, _ = out.WriteString("-") + } + _, _ = out.WriteString(strconv.FormatInt(int64(hours), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(sec), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(msec), 10)) + _, _ = out.WriteString("' HOUR_MICROSECOND") + return nil +} + +// Interval returns a representation of duration +// in a form "INTERVAL `hour:min:sec:microsec` HOUR_MICROSECOND" +func Interval(duration time.Duration) Expression { + negative := false + if duration < 0 { + negative = true + duration = -duration + } + return &intervalExpression{ + duration: duration, + negative: negative, + } +} + +var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%") + +func EscapeForLike(s string) string { + return likeEscaper.Replace(s) +} + +// Returns an escaped literal string +func Literal(v interface{}) Expression { + value, err := sqltypes.BuildValue(v) + if err != nil { + panic(errors.Wrap(err, "Invalid literal value")) + } + return &literalExpression{value: value} +} + +// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses +func And(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" AND "), + } +} + +// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses +func Or(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" OR "), + } +} + +func Like(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" LIKE ")) +} + +func LikeL(lhs Expression, val string) BoolExpression { + return Like(lhs, Literal(val)) +} + +func Regexp(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) +} + +func RegexpL(lhs Expression, val string) BoolExpression { + return Regexp(lhs, Literal(val)) +} + +// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses +func Add(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" + "), + } +} + +// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses +func Sub(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" - "), + } +} + +// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses +func Mul(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" * "), + } +} + +// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses +func Div(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" / "), + } +} + +// Returns a representation of "a=b" +func Eq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS ")) + } + return newBoolExpression(lhs, rhs, []byte("=")) +} + +// Returns a representation of "a=b", where b is a literal +func EqL(lhs Expression, val interface{}) BoolExpression { + return Eq(lhs, Literal(val)) +} + +// Returns a representation of "a!=b" +func Neq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) + } + return newBoolExpression(lhs, rhs, []byte("!=")) +} + +// Returns a representation of "a!=b", where b is a literal +func NeqL(lhs Expression, val interface{}) BoolExpression { + return Neq(lhs, Literal(val)) +} + +// Returns a representation of "ab" +func Gt(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">")) +} + +// Returns a representation of "a>b", where b is a literal +func GtL(lhs Expression, val interface{}) BoolExpression { + return Gt(lhs, Literal(val)) +} + +// Returns a representation of "a>=b" +func Gte(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">=")) +} + +// Returns a representation of "a>=b", where b is a literal +func GteL(lhs Expression, val interface{}) BoolExpression { + return Gte(lhs, Literal(val)) +} + +func BitOr(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" | "), + } +} + +func BitAnd(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" & "), + } +} + +func BitXor(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" ^ "), + } +} + +func Plus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" + "), + } +} + +func Minus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" - "), + } +} + +// in expression representation +type inExpression struct { + isExpression + isBoolExpression + + lhs Expression + rhs *listClause + + err error +} + +func (c *inExpression) SerializeSql(out *bytes.Buffer) error { + if c.err != nil { + return errors.Wrap(c.err, "Invalid IN expression") + } + + if c.lhs == nil { + return errors.Newf( + "lhs of in expression is nil. Generated sql: %s", + out.String()) + } + + // We'll serialize the lhs even if we don't need it to ensure no error + buf := &bytes.Buffer{} + + err := c.lhs.SerializeSql(buf) + if err != nil { + return err + } + + if c.rhs == nil { + _, _ = out.WriteString("FALSE") + return nil + } + + _, _ = out.WriteString(buf.String()) + _, _ = out.WriteString(" IN ") + + err = c.rhs.SerializeSql(out) + if err != nil { + return err + } + + return nil +} + +// Returns a representation of "a IN (b[0], ..., b[n-1])", where b is a list +// of literals valList must be a slice type +func In(lhs Expression, valList interface{}) BoolExpression { + var clauses []Clause + switch val := valList.(type) { + // This atrocious body of copy-paste code is due to the fact that if you + // try to merge the cases, you can't treat val as a list + case []int: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []float64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []string: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case [][]byte: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []time.Time: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Numeric: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Fractional: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.String: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Value: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + default: + return &inExpression{ + err: errors.Newf( + "Unknown value list type in IN clause: %s", + reflect.TypeOf(valList)), + } + } + + expr := &inExpression{lhs: lhs} + if len(clauses) > 0 { + expr.rhs = &listClause{clauses: clauses, includeParentheses: true} + } + return expr +} + +type ifExpression struct { + isExpression + conditional BoolExpression + trueExpression Expression + falseExpression Expression +} + +func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("IF(") + _ = exp.conditional.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.trueExpression.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.falseExpression.SerializeSql(out) + _, _ = out.WriteString(")") + return nil +} + +// Returns a representation of an if-expression, of the form: +// +// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE) +func If(conditional BoolExpression, + trueExpression Expression, + falseExpression Expression) Expression { + return &ifExpression{ + conditional: conditional, + trueExpression: trueExpression, + falseExpression: falseExpression, + } +} + +type columnValueExpression struct { + isExpression + column NonAliasColumn +} + +func ColumnValue(col NonAliasColumn) Expression { + return &columnValueExpression{ + column: col, + } +} + +func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("VALUES(") + _ = cv.column.SerializeSqlForColumnList(out) + _ = out.WriteByte(')') + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c17344a --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module b612.me/mysql/sqlbuilder + +go 1.20 + +require ( + b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2 + github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c46d44f --- /dev/null +++ b/go.sum @@ -0,0 +1,33 @@ +b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2 h1:gWGuBHC7hrmyhp9vX1UOOX5C9WRoFHPDjSvRfmP/nS4= +b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2/go.mod h1:Py9XWC9lc2cDhzfSPO7gqk07qZcPpJfk0aQ0iUZC5CQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd h1:s2vYw+2c+7GR1ccOaDuDcKsmNB/4RIxyu5liBm1VRbs= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd/go.mod h1:Vr/Q4p40Kce7JAHDITjDhiy/zk07W4tqD5YVi5FD0PA= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..92990b0 --- /dev/null +++ b/statement.go @@ -0,0 +1,1022 @@ +package sqlbuilder + +import ( + "bytes" + "fmt" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +type Statement interface { + // String returns generated SQL as string. + String(database string) (sql string, err error) +} + +type SelectStatement interface { + Statement + + Where(expression BoolExpression) SelectStatement + AndWhere(expression BoolExpression) SelectStatement + GroupBy(expressions ...Expression) SelectStatement + OrderBy(clauses ...OrderByClause) SelectStatement + Limit(limit int64) SelectStatement + Distinct() SelectStatement + WithSharedLock() SelectStatement + ForUpdate() SelectStatement + Offset(offset int64) SelectStatement + Comment(comment string) SelectStatement + Copy() SelectStatement +} + +type InsertStatement interface { + Statement + + // Add a row of values to the insert statement. + Add(row ...Expression) InsertStatement + AddOnDuplicateKeyUpdate(col NonAliasColumn, expr Expression) InsertStatement + Comment(comment string) InsertStatement + IgnoreDuplicates(ignore bool) InsertStatement +} + +// By default, rows selected by a UNION statement are out-of-order +// If you have an ORDER BY on an inner SELECT statement, the only thing +// it affects is the LIMIT clause on that inner statement (the ordering will +// still be out-of-order). +type UnionStatement interface { + Statement + + // Warning! You cannot include table names for the next 4 clauses, or + // you'll get errors like: + // Table 'server_file_journal' from one of the SELECTs cannot be used in + // global ORDER clause + Where(expression BoolExpression) UnionStatement + AndWhere(expression BoolExpression) UnionStatement + GroupBy(expressions ...Expression) UnionStatement + OrderBy(clauses ...OrderByClause) UnionStatement + + Limit(limit int64) UnionStatement + Offset(offset int64) UnionStatement +} + +type UpdateStatement interface { + Statement + + Set(column NonAliasColumn, expression Expression) UpdateStatement + Where(expression BoolExpression) UpdateStatement + OrderBy(clauses ...OrderByClause) UpdateStatement + Limit(limit int64) UpdateStatement + Comment(comment string) UpdateStatement +} + +type DeleteStatement interface { + Statement + + Where(expression BoolExpression) DeleteStatement + OrderBy(clauses ...OrderByClause) DeleteStatement + Limit(limit int64) DeleteStatement + Comment(comment string) DeleteStatement +} + +// LockStatement is used to take Read/Write lock on tables. +// See http://dev.mysql.com/doc/refman/5.0/en/lock-tables.html +type LockStatement interface { + Statement + + AddReadLock(table *Table) LockStatement + AddWriteLock(table *Table) LockStatement +} + +// UnlockStatement can be used to release table locks taken using LockStatement. +// NOTE: You can not selectively release a lock and continue to hold lock on +// another table. UnlockStatement releases all the lock held in the current +// session. +type UnlockStatement interface { + Statement +} + +// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID. +type GtidNextStatement interface { + Statement +} + +// +// UNION SELECT Statement ====================================================== +// + +func Union(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: true, + } +} + +func UnionAll(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: false, + } +} + +// Similar to selectStatementImpl, but less complete +type unionStatementImpl struct { + selects []SelectStatement + where BoolExpression + group *listClause + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + unique bool +} + +func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { + us.where = expression + return us +} + +// Further filter the query, instead of replacing the filter +func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { + if us.where == nil { + return us.Where(expression) + } + us.where = And(us.where, expression) + return us +} + +func (us *unionStatementImpl) GroupBy( + expressions ...Expression) UnionStatement { + + us.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + us.group.clauses[i] = e + } + return us +} + +func (us *unionStatementImpl) OrderBy( + clauses ...OrderByClause) UnionStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *unionStatementImpl) Limit(limit int64) UnionStatement { + us.limit = limit + return us +} + +func (us *unionStatementImpl) Offset(offset int64) UnionStatement { + us.offset = offset + return us +} + +func (us *unionStatementImpl) String(database string) (sql string, err error) { + if len(us.selects) == 0 { + return "", errors.Newf("Union statement must have at least one SELECT") + } + + if len(us.selects) == 1 { + return us.selects[0].String(database) + } + + // Union statements in MySQL require that the same number of columns in each subquery + var projections []Projection + + for _, statement := range us.selects { + // do a type assertion to get at the underlying struct + statementImpl, ok := statement.(*selectStatementImpl) + if !ok { + return "", errors.Newf( + "Expected inner select statement to be of type " + + "selectStatementImpl") + } + + // check that for limit for statements with order by clauses + if statementImpl.order != nil && statementImpl.limit < 0 { + return "", errors.Newf( + "All inner selects in Union statement must have LIMIT if " + + "they have ORDER BY") + } + + // check number of projections + if projections == nil { + projections = statementImpl.projections + } else { + if len(projections) != len(statementImpl.projections) { + return "", errors.Newf( + "All inner selects in Union statement must select the " + + "same number of columns. For sanity, you probably " + + "want to select the same table columns in the same " + + "order. If you are selecting on multiple tables, " + + "use Null to pad to the right number of fields.") + } + } + } + + buf := new(bytes.Buffer) + for i, statement := range us.selects { + if i != 0 { + if us.unique { + _, _ = buf.WriteString(" UNION ") + } else { + _, _ = buf.WriteString(" UNION ALL ") + } + } + _, _ = buf.WriteString("(") + selectSql, err := statement.String(database) + if err != nil { + return "", err + } + _, _ = buf.WriteString(selectSql) + _, _ = buf.WriteString(")") + } + + if us.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = us.where.SerializeSql(buf); err != nil { + return + } + } + + if us.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = us.group.SerializeSql(buf); err != nil { + return + } + } + + if us.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = us.order.SerializeSql(buf); err != nil { + return + } + } + + if us.limit >= 0 { + if us.offset >= 0 { + _, _ = buf.WriteString( + fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) + } + } + return buf.String(), nil +} + +// +// SELECT Statement ============================================================ +// + +func newSelectStatement( + table ReadableTable, + projections []Projection) SelectStatement { + + return &selectStatementImpl{ + table: table, + projections: projections, + limit: -1, + offset: -1, + withSharedLock: false, + forUpdate: false, + distinct: false, + } +} + +// NOTE: SelectStatement purposely does not implement the Table interface since +// mysql's subquery performance is horrible. +type selectStatementImpl struct { + table ReadableTable + projections []Projection + where BoolExpression + group *listClause + order *listClause + comment string + limit, offset int64 + withSharedLock bool + forUpdate bool + distinct bool +} + +func (s *selectStatementImpl) Copy() SelectStatement { + ret := *s + return &ret +} + +// Further filter the query, instead of replacing the filter +func (q *selectStatementImpl) AndWhere( + expression BoolExpression) SelectStatement { + + if q.where == nil { + return q.Where(expression) + } + q.where = And(q.where, expression) + return q +} + +func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { + q.where = expression + return q +} + +func (q *selectStatementImpl) GroupBy( + expressions ...Expression) SelectStatement { + + q.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + q.group.clauses[i] = e + } + return q +} + +func (q *selectStatementImpl) OrderBy( + clauses ...OrderByClause) SelectStatement { + + q.order = newOrderByListClause(clauses...) + return q +} + +func (q *selectStatementImpl) Limit(limit int64) SelectStatement { + q.limit = limit + return q +} + +func (q *selectStatementImpl) Distinct() SelectStatement { + q.distinct = true + return q +} + +func (q *selectStatementImpl) WithSharedLock() SelectStatement { + // We don't need to grab a read lock if we're going to grab a write one + if !q.forUpdate { + q.withSharedLock = true + } + return q +} + +func (q *selectStatementImpl) ForUpdate() SelectStatement { + // Clear a request for a shared lock if we're asking for a write one + q.withSharedLock = false + q.forUpdate = true + return q +} + +func (q *selectStatementImpl) Offset(offset int64) SelectStatement { + q.offset = offset + return q +} + +func (q *selectStatementImpl) Comment(comment string) SelectStatement { + q.comment = comment + return q +} + +// Return the properly escaped SQL statement, against the specified database +func (q *selectStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("SELECT ") + + if err = writeComment(q.comment, buf); err != nil { + return + } + + if q.distinct { + _, _ = buf.WriteString("DISTINCT ") + } + + if q.projections == nil || len(q.projections) == 0 { + return "", errors.Newf( + "No column selected. Generated sql: %s", + buf.String()) + } + + for i, col := range q.projections { + if i > 0 { + _ = buf.WriteByte(',') + } + if col == nil { + return "", errors.Newf( + "nil column selected. Generated sql: %s", + buf.String()) + } + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + _, _ = buf.WriteString(" FROM ") + if q.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + if err = q.table.SerializeSql(database, buf); err != nil { + return + } + + if q.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = q.where.SerializeSql(buf); err != nil { + return + } + } + + if q.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = q.group.SerializeSql(buf); err != nil { + return + } + } + + if q.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = q.order.SerializeSql(buf); err != nil { + return + } + } + + if q.limit >= 0 { + if q.offset >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) + } + } + + if q.forUpdate { + _, _ = buf.WriteString(" FOR UPDATE") + } else if q.withSharedLock { + _, _ = buf.WriteString(" LOCK IN SHARE MODE") + } + + return buf.String(), nil +} + +// +// INSERT Statement ============================================================ +// + +func newInsertStatement( + t WritableTable, + columns ...NonAliasColumn) InsertStatement { + + return &insertStatementImpl{ + table: t, + columns: columns, + rows: make([][]Expression, 0, 1), + onDuplicateKeyUpdates: make([]columnAssignment, 0, 0), + } +} + +type columnAssignment struct { + col NonAliasColumn + expr Expression +} + +type insertStatementImpl struct { + table WritableTable + columns []NonAliasColumn + rows [][]Expression + onDuplicateKeyUpdates []columnAssignment + comment string + ignore bool +} + +func (s *insertStatementImpl) Add( + row ...Expression) InsertStatement { + + s.rows = append(s.rows, row) + return s +} + +func (s *insertStatementImpl) AddOnDuplicateKeyUpdate( + col NonAliasColumn, + expr Expression) InsertStatement { + + s.onDuplicateKeyUpdates = append( + s.onDuplicateKeyUpdates, + columnAssignment{col, expr}) + + return s +} + +func (s *insertStatementImpl) IgnoreDuplicates(ignore bool) InsertStatement { + s.ignore = ignore + return s +} + +func (s *insertStatementImpl) Comment(comment string) InsertStatement { + s.comment = comment + return s +} + +func (s *insertStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("INSERT ") + if s.ignore { + _, _ = buf.WriteString("IGNORE ") + } + _, _ = buf.WriteString("INTO ") + + if err = writeComment(s.comment, buf); err != nil { + return + } + + if s.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = s.table.SerializeSql(database, buf); err != nil { + return + } + + if len(s.columns) == 0 { + return "", errors.Newf( + "No column specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" (") + for i, col := range s.columns { + if i > 0 { + _ = buf.WriteByte(',') + } + + if col == nil { + return "", errors.Newf( + "nil column in columns list. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + if len(s.rows) == 0 { + return "", errors.Newf( + "No row specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(") VALUES (") + for row_i, row := range s.rows { + if row_i > 0 { + _, _ = buf.WriteString(", (") + } + + if len(row) != len(s.columns) { + return "", errors.Newf( + "# of values does not match # of columns. Generated sql: %s", + buf.String()) + } + + for col_i, value := range row { + if col_i > 0 { + _ = buf.WriteByte(',') + } + + if value == nil { + return "", errors.Newf( + "nil value in row %d col %d. Generated sql: %s", + row_i, + col_i, + buf.String()) + } + + if err = value.SerializeSql(buf); err != nil { + return + } + } + _ = buf.WriteByte(')') + } + + if len(s.onDuplicateKeyUpdates) > 0 { + _, _ = buf.WriteString(" ON DUPLICATE KEY UPDATE ") + for i, colExpr := range s.onDuplicateKeyUpdates { + if i > 0 { + _, _ = buf.WriteString(", ") + } + + if colExpr.col == nil { + return "", errors.Newf( + ("nil column in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.col.SerializeSqlForColumnList(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + + if colExpr.expr == nil { + return "", errors.Newf( + ("nil expression in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.expr.SerializeSql(buf); err != nil { + return + } + } + } + + return buf.String(), nil +} + +// +// UPDATE statement =========================================================== +// + +func newUpdateStatement(table WritableTable) UpdateStatement { + return &updateStatementImpl{ + table: table, + updateValues: make(map[NonAliasColumn]Expression), + limit: -1, + } +} + +type updateStatementImpl struct { + table WritableTable + updateValues map[NonAliasColumn]Expression + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (u *updateStatementImpl) Set( + column NonAliasColumn, + expression Expression) UpdateStatement { + + u.updateValues[column] = expression + return u +} + +func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement { + u.where = expression + return u +} + +func (u *updateStatementImpl) OrderBy( + clauses ...OrderByClause) UpdateStatement { + + u.order = newOrderByListClause(clauses...) + return u +} + +func (u *updateStatementImpl) Limit(limit int64) UpdateStatement { + u.limit = limit + return u +} + +func (u *updateStatementImpl) Comment(comment string) UpdateStatement { + u.comment = comment + return u +} + +func (u *updateStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("UPDATE ") + + if err = writeComment(u.comment, buf); err != nil { + return + } + + if u.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = u.table.SerializeSql(database, buf); err != nil { + return + } + + if len(u.updateValues) == 0 { + return "", errors.Newf( + "No column updated. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" SET ") + addComma := false + + // Sorting is too hard in go, just create a second map ... + updateValues := make(map[string]Expression) + for col, expr := range u.updateValues { + if col == nil { + return "", errors.Newf( + "nil column. Generated sql: %s", + buf.String()) + } + + updateValues[col.Name()] = expr + } + + for _, col := range u.table.Columns() { + val, inMap := updateValues[col.Name()] + if !inMap { + continue + } + + if addComma { + _, _ = buf.WriteString(", ") + } + + if val == nil { + return "", errors.Newf( + "nil value. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSql(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + if err = val.SerializeSql(buf); err != nil { + return + } + + addComma = true + } + + if u.where == nil { + return "", errors.Newf( + "Updating without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = u.where.SerializeSql(buf); err != nil { + return + } + + if u.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = u.order.SerializeSql(buf); err != nil { + return + } + } + + if u.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit)) + } + + return buf.String(), nil +} + +// +// DELETE statement =========================================================== +// + +func newDeleteStatement(table WritableTable) DeleteStatement { + return &deleteStatementImpl{ + table: table, + limit: -1, + } +} + +type deleteStatementImpl struct { + table WritableTable + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement { + d.where = expression + return d +} + +func (d *deleteStatementImpl) OrderBy( + clauses ...OrderByClause) DeleteStatement { + + d.order = newOrderByListClause(clauses...) + return d +} + +func (d *deleteStatementImpl) Limit(limit int64) DeleteStatement { + d.limit = limit + return d +} + +func (d *deleteStatementImpl) Comment(comment string) DeleteStatement { + d.comment = comment + return d +} + +func (d *deleteStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("DELETE FROM ") + + if err = writeComment(d.comment, buf); err != nil { + return + } + + if d.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = d.table.SerializeSql(database, buf); err != nil { + return + } + + if d.where == nil { + return "", errors.Newf( + "Deleting without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = d.where.SerializeSql(buf); err != nil { + return + } + + if d.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = d.order.SerializeSql(buf); err != nil { + return + } + } + + if d.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", d.limit)) + } + + return buf.String(), nil +} + +// +// LOCK statement =========================================================== +// + +// NewLockStatement returns a SQL representing empty set of locks. You need to use +// AddReadLock/AddWriteLock to add tables that need to be locked. +// NOTE: You need at least one lock in the set for it to be a valid statement. +func NewLockStatement() LockStatement { + return &lockStatementImpl{} +} + +type lockStatementImpl struct { + locks []tableLock +} + +type tableLock struct { + t *Table + w bool +} + +// AddReadLock takes read lock on the table. +func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: false}) + return s +} + +// AddWriteLock takes write lock on the table. +func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: true}) + return s +} + +func (s *lockStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + if len(s.locks) == 0 { + return "", errors.New("No locks added") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("LOCK TABLES ") + + for idx, lock := range s.locks { + if lock.t == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = lock.t.SerializeSql(database, buf); err != nil { + return + } + + if lock.w { + _, _ = buf.WriteString(" WRITE") + } else { + _, _ = buf.WriteString(" READ") + } + + if idx != len(s.locks)-1 { + _, _ = buf.WriteString(", ") + } + } + + return buf.String(), nil +} + +// NewUnlockStatement returns SQL statement that can be used to release table locks +// grabbed by the current session. +func NewUnlockStatement() UnlockStatement { + return &unlockStatementImpl{} +} + +type unlockStatementImpl struct { +} + +func (s *unlockStatementImpl) String(database string) (sql string, err error) { + return "UNLOCK TABLES", nil +} + +// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. +func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { + return >idNextStatementImpl{ + sid: sid, + gno: gno, + } +} + +type gtidNextStatementImpl struct { + sid []byte + gno uint64 +} + +func (s *gtidNextStatementImpl) String(database string) (sql string, err error) { + // This statement sets a session local variable defining what the next transaction ID is. It + // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we + // don't have to worry about data corruption. + // Because of the string formatting (hex plus an integer), can't morph into another statement. + // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html + const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\"" + + buf := new(bytes.Buffer) + _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString, + s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno)) + return buf.String(), nil +} + +// +// Util functions ============================================================= +// + +// Once again, teisenberger is lazy. Here's a quick filter on comments +var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") + +func isValidComment(comment string) bool { + return validCommentRegexp.MatchString(comment) +} + +func writeComment(comment string, buf *bytes.Buffer) error { + if comment != "" { + _, _ = buf.WriteString("/* ") + if !isValidComment(comment) { + return errors.Newf("Invalid comment: %s", comment) + } + _, _ = buf.WriteString(comment) + _, _ = buf.WriteString(" */") + } + return nil +} + +func newOrderByListClause(clauses ...OrderByClause) *listClause { + ret := &listClause{ + clauses: make([]Clause, len(clauses), len(clauses)), + includeParentheses: false, + } + + for i, c := range clauses { + ret.clauses[i] = c + } + + return ret +} diff --git a/table.go b/table.go new file mode 100644 index 0000000..25c7a39 --- /dev/null +++ b/table.go @@ -0,0 +1,321 @@ +// Modeling of tables. This is where query preparation starts + +package sqlbuilder + +import ( + "bytes" + "fmt" + + "github.com/dropbox/godropbox/errors" +) + +// The sql table read interface. NOTE: NATURAL JOINs, and join "USING" clause +// are not supported. +type ReadableTable interface { + // Returns the list of columns that are in the current table expression. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + // Generates a select query on the current table. + Select(projections ...Projection) SelectStatement + + // Creates a inner join table expression using onCondition. + InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a left join table expression using onCondition. + LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a right join table expression using onCondition. + RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable +} + +// The sql table write interface. +type WritableTable interface { + // Returns the list of columns that are in the table. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + Insert(columns ...NonAliasColumn) InsertStatement + Update() UpdateStatement + Delete() DeleteStatement +} + +// Defines a physical table in the database that is both readable and writable. +// This function will panic if name is not valid +func NewTable(name string, columns ...NonAliasColumn) *Table { + if !validIdentifierName(name) { + panic("Invalid table name") + } + + t := &Table{ + name: name, + columns: columns, + columnLookup: make(map[string]NonAliasColumn), + } + for _, c := range columns { + err := c.setTableName(name) + if err != nil { + panic(err) + } + t.columnLookup[c.Name()] = c + } + + if len(columns) == 0 { + panic(fmt.Sprintf("Table %s has no columns", name)) + } + + return t +} + +type Table struct { + name string + columns []NonAliasColumn + columnLookup map[string]NonAliasColumn + // If not empty, the name of the index to force + forcedIndex string +} + +// Returns the specified column, or errors if it doesn't exist in the table +func (t *Table) getColumn(name string) (NonAliasColumn, error) { + if c, ok := t.columnLookup[name]; ok { + return c, nil + } + return nil, errors.Newf("No such column '%s' in table '%s'", name, t.name) +} + +// Returns a pseudo column representation of the column name. Error checking +// is deferred to SerializeSql. +func (t *Table) C(name string) NonAliasColumn { + return &deferredLookupColumn{ + table: t, + colName: name, + } +} + +// Returns all columns for a table as a slice of projections +func (t *Table) Projections() []Projection { + result := make([]Projection, 0) + + for _, col := range t.columns { + result = append(result, col) + } + + return result +} + +// Returns the table's name in the database +func (t *Table) Name() string { + return t.name +} + +// Returns a list of the table's columns +func (t *Table) Columns() []NonAliasColumn { + return t.columns +} + +// Returns a copy of this table, but with the specified index forced. +func (t *Table) ForceIndex(index string) *Table { + newTable := *t + newTable.forcedIndex = index + return &newTable +} + +// Generates the sql string for the current table expression. Note: the +// generated string may not be a valid/executable sql statement. +func (t *Table) SerializeSql(database string, out *bytes.Buffer) error { + //Momo modified. if database empty, not write + if database != "" { + _, _ = out.WriteString("`") + _, _ = out.WriteString(database) + _, _ = out.WriteString("`.") + } + _, _ = out.WriteString("`") + _, _ = out.WriteString(t.Name()) + _, _ = out.WriteString("`") + + if t.forcedIndex != "" { + if !validIdentifierName(t.forcedIndex) { + return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex) + } + _, _ = out.WriteString(" FORCE INDEX (`") + _, _ = out.WriteString(t.forcedIndex) + _, _ = out.WriteString("`)") + } + + return nil +} + +// Generates a select query on the current table. +func (t *Table) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +// Creates a inner join table expression using onCondition. +func (t *Table) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +// Creates a left join table expression using onCondition. +func (t *Table) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +// Creates a right join table expression using onCondition. +func (t *Table) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} + +func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { + return newInsertStatement(t, columns...) +} + +func (t *Table) Update() UpdateStatement { + return newUpdateStatement(t) +} + +func (t *Table) Delete() DeleteStatement { + return newDeleteStatement(t) +} + +type joinType int + +const ( + INNER_JOIN joinType = iota + LEFT_JOIN + RIGHT_JOIN +) + +// Join expressions are pseudo readable tables. +type joinTable struct { + lhs ReadableTable + rhs ReadableTable + join_type joinType + onCondition BoolExpression +} + +func newJoinTable( + lhs ReadableTable, + rhs ReadableTable, + join_type joinType, + onCondition BoolExpression) ReadableTable { + + return &joinTable{ + lhs: lhs, + rhs: rhs, + join_type: join_type, + onCondition: onCondition, + } +} + +func InnerJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, INNER_JOIN, onCondition) +} + +func LeftJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, LEFT_JOIN, onCondition) +} + +func RightJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) +} + +func (t *joinTable) Columns() []NonAliasColumn { + columns := make([]NonAliasColumn, 0) + columns = append(columns, t.lhs.Columns()...) + columns = append(columns, t.rhs.Columns()...) + + return columns +} + +func (t *joinTable) SerializeSql( + database string, + out *bytes.Buffer) (err error) { + + if t.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if t.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if t.onCondition == nil { + return errors.Newf("nil onCondition. Generated sql: %s", out.String()) + } + + if err = t.lhs.SerializeSql(database, out); err != nil { + return + } + + switch t.join_type { + case INNER_JOIN: + _, _ = out.WriteString(" JOIN ") + case LEFT_JOIN: + _, _ = out.WriteString(" LEFT JOIN ") + case RIGHT_JOIN: + _, _ = out.WriteString(" RIGHT JOIN ") + } + + if err = t.rhs.SerializeSql(database, out); err != nil { + return + } + + _, _ = out.WriteString(" ON ") + if err = t.onCondition.SerializeSql(out); err != nil { + return + } + + return nil +} + +func (t *joinTable) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +func (t *joinTable) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +func (t *joinTable) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +func (t *joinTable) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} diff --git a/test_utils.go b/test_utils.go new file mode 100644 index 0000000..a7a0250 --- /dev/null +++ b/test_utils.go @@ -0,0 +1,26 @@ +package sqlbuilder + +var table1Col1 = IntColumn("col1", Nullable) +var table1Col2 = IntColumn("col2", Nullable) +var table1Col3 = IntColumn("col3", Nullable) +var table1Col4 = DateTimeColumn("col4", Nullable) +var table1 = NewTable( + "table1", + table1Col1, + table1Col2, + table1Col3, + table1Col4) + +var table2Col3 = IntColumn("col3", Nullable) +var table2Col4 = IntColumn("col4", Nullable) +var table2 = NewTable( + "table2", + table2Col3, + table2Col4) + +var table3Col1 = IntColumn("col1", Nullable) +var table3Col2 = IntColumn("col2", Nullable) +var table3 = NewTable( + "table3", + table3Col1, + table3Col2) diff --git a/types.go b/types.go new file mode 100644 index 0000000..c9d05ea --- /dev/null +++ b/types.go @@ -0,0 +1,79 @@ +package sqlbuilder + +import ( + "bytes" +) + +type Clause interface { + SerializeSql(out *bytes.Buffer) error +} + +// A clause that can be used in order by +type OrderByClause interface { + Clause + isOrderByClauseInterface +} + +// An expression +type Expression interface { + Clause + isExpressionInterface +} + +type BoolExpression interface { + Clause + isBoolExpressionInterface +} + +// A clause that is selectable. +type Projection interface { + Clause + isProjectionInterface + SerializeSqlForColumnList(out *bytes.Buffer) error +} + +// +// Boiler plates ... +// + +type isOrderByClauseInterface interface { + isOrderByClauseType() +} + +type isOrderByClause struct { +} + +func (o *isOrderByClause) isOrderByClauseType() { +} + +type isExpressionInterface interface { + isExpressionType() +} + +type isExpression struct { + isOrderByClause // can always use expression in order by. +} + +func (e *isExpression) isExpressionType() { +} + +type isBoolExpressionInterface interface { + isExpressionInterface + isBoolExpressionType() +} + +type isBoolExpression struct { +} + +func (e *isBoolExpression) isBoolExpressionType() { +} + +type isProjectionInterface interface { + isProjectionType() +} + +type isProjection struct { +} + +func (p *isProjection) isProjectionType() { +}