diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index dbdbe252b3..4e9c8b02d2 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -3,6 +3,7 @@ package compiler import ( "errors" "fmt" + "math/rand" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" @@ -596,10 +597,14 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro if err != nil { return nil, err } + rel := &ast.TableName{} + if n.Alias != nil && n.Alias.Aliasname != nil { + rel.Name = *n.Alias.Aliasname + } else { + rel.Name = fmt.Sprintf("unnamed_subquery_%d", rand.Int63()) + } tables = append(tables, &Table{ - Rel: &ast.TableName{ - Name: *n.Alias.Aliasname, - }, + Rel: rel, Columns: cols, }) diff --git a/internal/compiler/query_catalog.go b/internal/compiler/query_catalog.go index 80b59d876c..29abace6e2 100644 --- a/internal/compiler/query_catalog.go +++ b/internal/compiler/query_catalog.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "math/rand" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/catalog" @@ -9,13 +10,15 @@ import ( ) type QueryCatalog struct { - catalog *catalog.Catalog - ctes map[string]*Table - embeds rewrite.EmbedSet + catalog *catalog.Catalog + ctes map[string]*Table + fromClauses map[string]*Table + embeds rewrite.EmbedSet } func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) { var with *ast.WithClause + var from *ast.List switch n := node.(type) { case *ast.DeleteStmt: with = n.WithClause @@ -23,12 +26,20 @@ func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embed with = n.WithClause case *ast.UpdateStmt: with = n.WithClause + from = n.FromClause case *ast.SelectStmt: with = n.WithClause + from = n.FromClause default: with = nil + from = nil + } + qc := &QueryCatalog{ + catalog: c, + ctes: map[string]*Table{}, + fromClauses: map[string]*Table{}, + embeds: embeds, } - qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}, embeds: embeds} if with != nil { for _, item := range with.Ctes.Items { if cte, ok := item.(*ast.CommonTableExpr); ok { @@ -60,6 +71,42 @@ func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embed } } } + if from != nil { + for _, item := range from.Items { + if rs, ok := item.(*ast.RangeSubselect); ok { + cols, err := comp.outputColumns(qc, rs.Subquery) + if err != nil { + return nil, err + } + var names []string + if rs.Alias != nil && rs.Alias.Colnames != nil { + for _, item := range rs.Alias.Colnames.Items { + if val, ok := item.(*ast.String); ok { + names = append(names, val.Str) + } else { + names = append(names, "") + } + } + } + rel := &ast.TableName{} + if rs.Alias != nil && rs.Alias.Aliasname != nil { + rel.Name = *rs.Alias.Aliasname + } else { + rel.Name = fmt.Sprintf("unaliased_table_%d", rand.Int63()) + } + for i := range cols { + cols[i].Table = rel + if len(names) > i { + cols[i].Name = names[i] + } + } + qc.fromClauses[rel.Name] = &Table{ + Rel: rel, + Columns: cols, + } + } + } + } return qc, nil } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..49fe52e898 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -80,6 +80,30 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, aliasMap[*rv.Alias.Aliasname] = fqn } } + if qc != nil { + for _, f := range qc.fromClauses { + catCols := make([]*catalog.Column, 0, len(f.Columns)) + for _, col := range f.Columns { + catCols = append(catCols, &catalog.Column{ + Name: col.Name, + Type: ast.TypeName{Name: col.DataType}, + IsNotNull: col.NotNull, + IsUnsigned: col.Unsigned, + IsArray: col.IsArray, + ArrayDims: col.ArrayDims, + Comment: col.Comment, + Length: col.Length, + }) + } + + if err := indexTable(catalog.Table{ + Rel: f.Rel, + Columns: catCols, + }); err != nil { + return nil, err + } + } + } // resolve a table for an embed for _, embed := range embeds { diff --git a/internal/endtoend/testdata/join_alias/mysql/go/query.sql.go b/internal/endtoend/testdata/join_alias/mysql/go/query.sql.go index 836f62c6d4..1ac7681d5e 100644 --- a/internal/endtoend/testdata/join_alias/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/join_alias/mysql/go/query.sql.go @@ -80,3 +80,84 @@ func (q *Queries) AliasJoin(ctx context.Context, id uint64) ([]AliasJoinRow, err } return items, nil } + +const columnAlias = `-- name: ColumnAlias :many +SELECT n FROM (SELECT 1 AS n) WHERE n <= ? +` + +func (q *Queries) ColumnAlias(ctx context.Context, n int32) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, columnAlias, n) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var n int32 + if err := rows.Scan(&n); err != nil { + return nil, err + } + items = append(items, n) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const columnAndQueryAlias = `-- name: ColumnAndQueryAlias :many +SELECT n FROM (SELECT 1 AS n) AS x WHERE n <= ? +` + +func (q *Queries) ColumnAndQueryAlias(ctx context.Context, n int32) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, columnAndQueryAlias, n) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var n int32 + if err := rows.Scan(&n); err != nil { + return nil, err + } + items = append(items, n) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const subqueryAlias = `-- name: SubqueryAlias :many +SELECT n FROM (SELECT 1 AS n) AS x WHERE x.n <= ? +` + +func (q *Queries) SubqueryAlias(ctx context.Context, n int32) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, subqueryAlias, n) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var n int32 + if err := rows.Scan(&n); err != nil { + return nil, err + } + items = append(items, n) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/join_alias/mysql/query.sql b/internal/endtoend/testdata/join_alias/mysql/query.sql index 9b087bcae7..77ed3cd3ad 100644 --- a/internal/endtoend/testdata/join_alias/mysql/query.sql +++ b/internal/endtoend/testdata/join_alias/mysql/query.sql @@ -9,3 +9,12 @@ SELECT * FROM foo f JOIN bar b ON b.id = f.id WHERE f.id = ?; + +-- name: SubqueryAlias :many +SELECT * FROM (SELECT 1 AS n) AS x WHERE x.n <= ?; + +-- name: ColumnAlias :many +SELECT * FROM (SELECT 1 AS n) WHERE n <= ?; + +-- name: ColumnAndQueryAlias :many +SELECT * FROM (SELECT 1 AS n) AS x WHERE n <= ?; diff --git a/internal/endtoend/testdata/select_subquery_alias/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/select_subquery_alias/postgresql/pgx/go/query.sql.go index 8744642eb0..0cb154dba8 100644 --- a/internal/endtoend/testdata/select_subquery_alias/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/select_subquery_alias/postgresql/pgx/go/query.sql.go @@ -24,8 +24,8 @@ where amounts.last_balance < $2 ` type FindWalletsParams struct { - Column1 pgtype.Text - Column2 pgtype.Numeric + Type string + LastBalance pgtype.Numeric } type FindWalletsRow struct { @@ -36,7 +36,7 @@ type FindWalletsRow struct { } func (q *Queries) FindWallets(ctx context.Context, arg FindWalletsParams) ([]FindWalletsRow, error) { - rows, err := q.db.Query(ctx, findWallets, arg.Column1, arg.Column2) + rows, err := q.db.Query(ctx, findWallets, arg.Type, arg.LastBalance) if err != nil { return nil, err }