Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func (s *SelectWithUnionQuery) statementNode() {}

// SelectIntersectExceptQuery represents SELECT ... INTERSECT/EXCEPT ... queries.
type SelectIntersectExceptQuery struct {
Position token.Position `json:"-"`
Selects []Statement `json:"selects"`
Position token.Position `json:"-"`
Selects []Statement `json:"selects"`
Operators []string `json:"operators,omitempty"` // "INTERSECT", "EXCEPT", etc. for each operator between selects
}

func (s *SelectIntersectExceptQuery) Pos() token.Position { return s.Position }
Expand Down
8 changes: 7 additions & 1 deletion internal/format/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ func formatLiteral(sb *strings.Builder, lit *ast.Literal) {
switch lit.Type {
case ast.LiteralString:
sb.WriteString("'")
// Escape single quotes in the string
// Escape backslashes and single quotes in the string
s := lit.Value.(string)
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, "'", "''")
sb.WriteString(s)
sb.WriteString("'")
Expand Down Expand Up @@ -289,6 +290,11 @@ func formatBinaryExpr(sb *strings.Builder, expr *ast.BinaryExpr) {
// formatUnaryExpr formats a unary expression.
func formatUnaryExpr(sb *strings.Builder, expr *ast.UnaryExpr) {
sb.WriteString(expr.Op)
// Add space after word operators like NOT
op := strings.ToUpper(expr.Op)
if op == "NOT" {
sb.WriteString(" ")
}
Expression(sb, expr.Operand)
}

Expand Down
4 changes: 4 additions & 0 deletions internal/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func Statement(sb *strings.Builder, stmt ast.Statement) {
formatDetachQuery(sb, s)
case *ast.AttachQuery:
formatAttachQuery(sb, s)
case *ast.ShowPrivilegesQuery:
formatShowPrivilegesQuery(sb, s)
case *ast.ShowCreateQuotaQuery:
formatShowCreateQuotaQuery(sb, s)
default:
// Fallback for unhandled statements
}
Expand Down
30 changes: 29 additions & 1 deletion internal/format/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,15 @@ func formatOrderByElement(sb *strings.Builder, o *ast.OrderByElement) {
func formatSelectIntersectExceptQuery(sb *strings.Builder, q *ast.SelectIntersectExceptQuery) {
for i, sel := range q.Selects {
if i > 0 {
sb.WriteString(" ")
// Get the operator between selects (operators[i-1] corresponds to the operator before selects[i])
opIdx := i - 1
if opIdx < len(q.Operators) {
sb.WriteString(" ")
sb.WriteString(q.Operators[opIdx])
sb.WriteString(" ")
} else {
sb.WriteString(" ")
}
}
Statement(sb, sel)
}
Expand Down Expand Up @@ -1084,3 +1092,23 @@ func formatAttachQuery(sb *strings.Builder, q *ast.AttachQuery) {
}
sb.WriteString(q.Table)
}

// formatShowPrivilegesQuery formats a SHOW PRIVILEGES statement.
func formatShowPrivilegesQuery(sb *strings.Builder, q *ast.ShowPrivilegesQuery) {
if q == nil {
return
}
sb.WriteString("SHOW PRIVILEGES")
}

// formatShowCreateQuotaQuery formats a SHOW CREATE QUOTA statement.
func formatShowCreateQuotaQuery(sb *strings.Builder, q *ast.ShowCreateQuotaQuery) {
if q == nil {
return
}
sb.WriteString("SHOW CREATE QUOTA")
if q.Name != "" {
sb.WriteString(" ")
sb.WriteString(q.Name)
}
}
Loading