44 "context"
55 "database/sql"
66 "database/sql/driver"
7+ "errors"
78 "fmt"
89 "io"
910 "net"
@@ -1859,34 +1860,34 @@ func TestStmtQueryContext(t *testing.T) {
18591860 defer db .Close ()
18601861
18611862 tests := []struct {
1862- name string
1863- ctx func () (context.Context , context.CancelFunc )
1864- sql string
1865- err error
1863+ name string
1864+ ctx func () (context.Context , context.CancelFunc )
1865+ sql string
1866+ cancelExpected bool
18661867 }{
18671868 {
18681869 name : "context.Background" ,
18691870 ctx : func () (context.Context , context.CancelFunc ) {
18701871 return context .Background (), nil
18711872 },
1872- sql : "SELECT pg_sleep(1);" ,
1873- err : nil ,
1873+ sql : "SELECT pg_sleep(1);" ,
1874+ cancelExpected : false ,
18741875 },
18751876 {
18761877 name : "context.WithTimeout exceeded" ,
18771878 ctx : func () (context.Context , context.CancelFunc ) {
18781879 return context .WithTimeout (context .Background (), 1 * time .Second )
18791880 },
1880- sql : "SELECT pg_sleep(10);" ,
1881- err : & Error { Message : "canceling statement due to user request" } ,
1881+ sql : "SELECT pg_sleep(10);" ,
1882+ cancelExpected : true ,
18821883 },
18831884 {
18841885 name : "context.WithTimeout" ,
18851886 ctx : func () (context.Context , context.CancelFunc ) {
18861887 return context .WithTimeout (context .Background (), time .Minute )
18871888 },
1888- sql : "SELECT pg_sleep(1);" ,
1889- err : nil ,
1889+ sql : "SELECT pg_sleep(1);" ,
1890+ cancelExpected : false ,
18901891 },
18911892 }
18921893 for _ , tt := range tests {
@@ -1900,11 +1901,12 @@ func TestStmtQueryContext(t *testing.T) {
19001901 t .Fatal (err )
19011902 }
19021903 _ , err = stmt .QueryContext (ctx )
1904+ pgErr := (* Error )(nil )
19031905 switch {
1904- case (err != nil ) != ( tt .err != nil ) :
1905- t .Fatalf ("stmt.QueryContext() unexpected nil err got = %v, expected = %v" , err , tt .err )
1906- case (err != nil && tt .err != nil ) && ( err . Error () != tt . err . Error () ):
1907- t .Errorf ("stmt.QueryContext() got = %v, expected = %v" , err .Error (), tt .err . Error () )
1906+ case (err != nil ) != tt .cancelExpected :
1907+ t .Fatalf ("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v" , err , tt .cancelExpected )
1908+ case (err != nil && tt .cancelExpected ) && ! ( errors . As ( err , & pgErr ) && pgErr . Code == cancelErrorCode ):
1909+ t .Errorf ("stmt.QueryContext() got = %v, cancelExpected = %v" , err .Error (), tt .cancelExpected )
19081910 }
19091911 })
19101912 }
@@ -1915,34 +1917,34 @@ func TestStmtExecContext(t *testing.T) {
19151917 defer db .Close ()
19161918
19171919 tests := []struct {
1918- name string
1919- ctx func () (context.Context , context.CancelFunc )
1920- sql string
1921- err error
1920+ name string
1921+ ctx func () (context.Context , context.CancelFunc )
1922+ sql string
1923+ cancelExpected bool
19221924 }{
19231925 {
19241926 name : "context.Background" ,
19251927 ctx : func () (context.Context , context.CancelFunc ) {
19261928 return context .Background (), nil
19271929 },
1928- sql : "SELECT pg_sleep(1);" ,
1929- err : nil ,
1930+ sql : "SELECT pg_sleep(1);" ,
1931+ cancelExpected : false ,
19301932 },
19311933 {
19321934 name : "context.WithTimeout exceeded" ,
19331935 ctx : func () (context.Context , context.CancelFunc ) {
19341936 return context .WithTimeout (context .Background (), 1 * time .Second )
19351937 },
1936- sql : "SELECT pg_sleep(10);" ,
1937- err : & Error { Message : "canceling statement due to user request" } ,
1938+ sql : "SELECT pg_sleep(10);" ,
1939+ cancelExpected : true ,
19381940 },
19391941 {
19401942 name : "context.WithTimeout" ,
19411943 ctx : func () (context.Context , context.CancelFunc ) {
19421944 return context .WithTimeout (context .Background (), time .Minute )
19431945 },
1944- sql : "SELECT pg_sleep(1);" ,
1945- err : nil ,
1946+ sql : "SELECT pg_sleep(1);" ,
1947+ cancelExpected : false ,
19461948 },
19471949 }
19481950 for _ , tt := range tests {
@@ -1956,11 +1958,12 @@ func TestStmtExecContext(t *testing.T) {
19561958 t .Fatal (err )
19571959 }
19581960 _ , err = stmt .ExecContext (ctx )
1961+ pgErr := (* Error )(nil )
19591962 switch {
1960- case (err != nil ) != ( tt .err != nil ) :
1961- t .Fatalf ("stmt.ExecContext () unexpected nil err got = %v, expected = %v" , err , tt .err )
1962- case (err != nil && tt .err != nil ) && ( err . Error () != tt . err . Error () ):
1963- t .Errorf ("stmt.ExecContext () got = %v, expected = %v" , err .Error (), tt .err . Error () )
1963+ case (err != nil ) != tt .cancelExpected :
1964+ t .Fatalf ("stmt.QueryContext () unexpected nil err got = %v, cancelExpected = %v" , err , tt .cancelExpected )
1965+ case (err != nil && tt .cancelExpected ) && ! ( errors . As ( err , & pgErr ) && pgErr . Code == cancelErrorCode ):
1966+ t .Errorf ("stmt.QueryContext () got = %v, cancelExpected = %v" , err .Error (), tt .cancelExpected )
19641967 }
19651968 })
19661969 }
0 commit comments