Skip to content

Commit

Permalink
Permit a scalar type for T in iter.SQL. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
bobg authored Jul 7, 2024
1 parent 83d5429 commit 2cd7a2b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 6 deletions.
42 changes: 37 additions & 5 deletions iter/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ type QueryerContext interface {
}

// SQL performs a query against db and returns the results as an iterator of type T.
// T must be a struct type whose fields have the same types,
//
// If the query produces a single value per row,
// T may be any scalar type (bool, int, float, string)
// into which the values can be scanned.
//
// Otherwise T must be a struct type whose fields have the same types,
// in the same order,
// as the values being queried.
// The values produced by the iterator will be instances of that struct type,
Expand Down Expand Up @@ -50,12 +55,21 @@ func (e sqlKindError) Error() string {
func sqlhelper[T any](ctx context.Context, rows *sql.Rows) (Of[T], error) {
var t T
tt := reflect.TypeOf(t)
if tt.Kind() != reflect.Struct {
switch tt.Kind() {
case reflect.Struct:
return sqlhelperStruct[T](ctx, tt, rows), nil

case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
return sqlhelperScalar[T](ctx, tt, rows), nil

default:
return nil, sqlKindError{kind: tt.Kind()}
}
nfields := tt.NumField()
}

res := Go(func(ch chan<- T) error {
func sqlhelperStruct[T any](ctx context.Context, tt reflect.Type, rows *sql.Rows) Of[T] {
nfields := tt.NumField()
return Go(func(ch chan<- T) error {
defer rows.Close()

for rows.Next() {
Expand Down Expand Up @@ -89,6 +103,24 @@ func sqlhelper[T any](ctx context.Context, rows *sql.Rows) (Of[T], error) {
}
return rows.Err()
})
}

func sqlhelperScalar[T any](ctx context.Context, tt reflect.Type, rows *sql.Rows) Of[T] {
return Go(func(ch chan<- T) error {
defer rows.Close()

return res, nil
for rows.Next() {
var val T
if err := rows.Scan(&val); err != nil {
return fmt.Errorf("scanning row: %w", err)
}

select {
case <-ctx.Done():
return ctx.Err()
case ch <- val:
}
}
return rows.Err()
})
}
17 changes: 16 additions & 1 deletion iter/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ func TestSQL(t *testing.T) {

const q = `SELECT name, salary FROM employees ORDER BY name`

t.Run("Scalar", func(t *testing.T) {
it, err := SQL[string](ctx, db, `SELECT name FROM employees ORDER BY name DESC`)
if err != nil {
t.Fatal(err)
}
got, err := ToSlice(it)
if err != nil {
t.Fatal(err)
}
want := []string{"dave", "carol", "bill", "alice"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
})

t.Run("SQL", func(t *testing.T) {
it, err := SQL[employee](ctx, db, q)
if err != nil {
Expand Down Expand Up @@ -103,7 +118,7 @@ func TestSQL(t *testing.T) {
})

t.Run("KindError", func(t *testing.T) {
_, err := SQL[int](ctx, db, q)
_, err := SQL[*int](ctx, db, q)

var e sqlKindError
if !errors.As(err, &e) {
Expand Down

0 comments on commit 2cd7a2b

Please sign in to comment.