Skip to content

Commit 1ef134d

Browse files
authored
Merge pull request #1074 from lib/fix-assertions
Avoid asserting on error message for cancel tests
2 parents 8446d16 + 4b55993 commit 1ef134d

3 files changed

Lines changed: 45 additions & 37 deletions

File tree

conn_test.go

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"errors"
78
"fmt"
89
"io"
910
"net"
@@ -1859,34 +1860,34 @@ func TestStmtQueryContext(t *testing.T) {
18591860
defer db.Close()
18601861

18611862
tests := []struct {
1862-
name string
1863-
ctx func() (context.Context, context.CancelFunc)
1864-
sql string
1865-
err error
1863+
name string
1864+
ctx func() (context.Context, context.CancelFunc)
1865+
sql string
1866+
cancelExpected bool
18661867
}{
18671868
{
18681869
name: "context.Background",
18691870
ctx: func() (context.Context, context.CancelFunc) {
18701871
return context.Background(), nil
18711872
},
1872-
sql: "SELECT pg_sleep(1);",
1873-
err: nil,
1873+
sql: "SELECT pg_sleep(1);",
1874+
cancelExpected: false,
18741875
},
18751876
{
18761877
name: "context.WithTimeout exceeded",
18771878
ctx: func() (context.Context, context.CancelFunc) {
18781879
return context.WithTimeout(context.Background(), 1*time.Second)
18791880
},
1880-
sql: "SELECT pg_sleep(10);",
1881-
err: &Error{Message: "canceling statement due to user request"},
1881+
sql: "SELECT pg_sleep(10);",
1882+
cancelExpected: true,
18821883
},
18831884
{
18841885
name: "context.WithTimeout",
18851886
ctx: func() (context.Context, context.CancelFunc) {
18861887
return context.WithTimeout(context.Background(), time.Minute)
18871888
},
1888-
sql: "SELECT pg_sleep(1);",
1889-
err: nil,
1889+
sql: "SELECT pg_sleep(1);",
1890+
cancelExpected: false,
18901891
},
18911892
}
18921893
for _, tt := range tests {
@@ -1900,11 +1901,12 @@ func TestStmtQueryContext(t *testing.T) {
19001901
t.Fatal(err)
19011902
}
19021903
_, err = stmt.QueryContext(ctx)
1904+
pgErr := (*Error)(nil)
19031905
switch {
1904-
case (err != nil) != (tt.err != nil):
1905-
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1906-
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1907-
t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1906+
case (err != nil) != tt.cancelExpected:
1907+
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
1908+
case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
1909+
t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
19081910
}
19091911
})
19101912
}
@@ -1915,34 +1917,34 @@ func TestStmtExecContext(t *testing.T) {
19151917
defer db.Close()
19161918

19171919
tests := []struct {
1918-
name string
1919-
ctx func() (context.Context, context.CancelFunc)
1920-
sql string
1921-
err error
1920+
name string
1921+
ctx func() (context.Context, context.CancelFunc)
1922+
sql string
1923+
cancelExpected bool
19221924
}{
19231925
{
19241926
name: "context.Background",
19251927
ctx: func() (context.Context, context.CancelFunc) {
19261928
return context.Background(), nil
19271929
},
1928-
sql: "SELECT pg_sleep(1);",
1929-
err: nil,
1930+
sql: "SELECT pg_sleep(1);",
1931+
cancelExpected: false,
19301932
},
19311933
{
19321934
name: "context.WithTimeout exceeded",
19331935
ctx: func() (context.Context, context.CancelFunc) {
19341936
return context.WithTimeout(context.Background(), 1*time.Second)
19351937
},
1936-
sql: "SELECT pg_sleep(10);",
1937-
err: &Error{Message: "canceling statement due to user request"},
1938+
sql: "SELECT pg_sleep(10);",
1939+
cancelExpected: true,
19381940
},
19391941
{
19401942
name: "context.WithTimeout",
19411943
ctx: func() (context.Context, context.CancelFunc) {
19421944
return context.WithTimeout(context.Background(), time.Minute)
19431945
},
1944-
sql: "SELECT pg_sleep(1);",
1945-
err: nil,
1946+
sql: "SELECT pg_sleep(1);",
1947+
cancelExpected: false,
19461948
},
19471949
}
19481950
for _, tt := range tests {
@@ -1956,11 +1958,12 @@ func TestStmtExecContext(t *testing.T) {
19561958
t.Fatal(err)
19571959
}
19581960
_, err = stmt.ExecContext(ctx)
1961+
pgErr := (*Error)(nil)
19591962
switch {
1960-
case (err != nil) != (tt.err != nil):
1961-
t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1962-
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1963-
t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1963+
case (err != nil) != tt.cancelExpected:
1964+
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
1965+
case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
1966+
t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
19641967
}
19651968
})
19661969
}

go18_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"errors"
78
"runtime"
89
"strings"
910
"testing"
@@ -75,6 +76,8 @@ func TestMultipleSimpleQuery(t *testing.T) {
7576

7677
const contextRaceIterations = 100
7778

79+
const cancelErrorCode ErrorCode = "57014"
80+
7881
func TestContextCancelExec(t *testing.T) {
7982
db := openTestConn(t)
8083
defer db.Close()
@@ -87,7 +90,7 @@ func TestContextCancelExec(t *testing.T) {
8790
// Not canceled until after the exec has started.
8891
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
8992
t.Fatal("expected error")
90-
} else if err.Error() != "pq: canceling statement due to user request" {
93+
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
9194
t.Fatalf("unexpected error: %s", err)
9295
}
9396

@@ -125,7 +128,7 @@ func TestContextCancelQuery(t *testing.T) {
125128
// Not canceled until after the exec has started.
126129
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
127130
t.Fatal("expected error")
128-
} else if err.Error() != "pq: canceling statement due to user request" {
131+
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
129132
t.Fatalf("unexpected error: %s", err)
130133
}
131134

@@ -215,7 +218,7 @@ func TestContextCancelBegin(t *testing.T) {
215218
// Not canceled until after the exec has started.
216219
if _, err := tx.Exec("select pg_sleep(1)"); err == nil {
217220
t.Fatal("expected error")
218-
} else if err.Error() != "pq: canceling statement due to user request" {
221+
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
219222
t.Fatalf("unexpected error: %s", err)
220223
}
221224

@@ -240,8 +243,8 @@ func TestContextCancelBegin(t *testing.T) {
240243
cancel()
241244
if err != nil {
242245
t.Fatal(err)
243-
} else if err := tx.Rollback(); err != nil &&
244-
err.Error() != "pq: canceling statement due to user request" &&
246+
} else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil &&
247+
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) &&
245248
err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled {
246249
t.Fatal(err)
247250
}

issues_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pq
22

33
import (
44
"context"
5+
"errors"
56
"testing"
67
"time"
78
)
@@ -51,10 +52,9 @@ func TestIssue1046(t *testing.T) {
5152
t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since)
5253
t.Fail()
5354
}
54-
expectedErr := &Error{Message: "canceling statement due to user request"}
55-
if err == nil || err.Error() != expectedErr.Error() {
55+
if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
5656
t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err())
57-
t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr)
57+
t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode)
5858
t.Fail()
5959
}
6060
}
@@ -72,7 +72,9 @@ func TestIssue1062(t *testing.T) {
7272

7373
var v int
7474
err := row.Scan(&v)
75-
if err != nil && err != context.Canceled && err.Error() != "pq: canceling statement due to user request" {
75+
if pgErr := (*Error)(nil); err != nil &&
76+
err != context.Canceled &&
77+
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
7678
t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1)
7779
}
7880
}

0 commit comments

Comments
 (0)