Open
Show file tree
Hide file tree
Changes from all commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Failed to load files.
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
module .com/go-sql-driver/mysql

go 1.18

require golang.org/x/sys v0.10.0 // indirect
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,6 +14,7 @@ import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
Expand DownExpand Up@@ -44,12 +45,24 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {

// check packet sync [8 bit]
if data[3] != mc.sequence {
var syncErr error
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
syncErr = ErrPktSyncMul
} else {
syncErr = ErrPktSync
}
return nil, ErrPktSync

if prevData != nil {
return nil, syncErr
} else {
// log and ignore seqno mismatch error.
// MySQL sometimes sends wrong sequence no.
mc.cfg.Logger.Print(syncErr)
mc.sequence = data[3] + 1
}
} else {
mc.sequence++
}
mc.sequence++

// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
Expand DownExpand Up@@ -89,6 +102,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
}

// used in conncheck.go
var errUnexpectedEvent = errors.New("recieved unexpected event")

// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4
Expand All@@ -111,18 +127,29 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
err = connCheck(conn)
if err != nil {
if err == errUnexpectedEvent {
_ = conn.SetReadDeadline(time.Now().Add(time.Second))
var data []byte
data, err = mc.readPacket()

if err == nil {
if data[0] == iERR {
err = mc.handleErrorPacket(data)
} else {
err = fmt.Errorf("unexpected packet: % x", data[:128])
}
} else {
err = fmt.Errorf("readPacket(): %w", err)
}
}

mc.cfg.Logger.Print("checkConn() failed: ", err)
mc.Close()
return driver.ErrBadConn
}
}
if err != nil {
mc.cfg.Logger.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}

for {
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -11,6 +11,7 @@ package mysql
import (
"bytes"
"errors"
"fmt"
"net"
"testing"
"time"
Expand DownExpand Up@@ -132,31 +133,57 @@ func TestReadPacketSingleByte(t *testing.T) {
}
}

type mockLogger struct {
bytes.Buffer
}

func (ml *mockLogger) Print(v ...any) {
ml.WriteString(fmt.Sprint(v...) + "\n")
}

func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
}
logger := &mockLogger{}
mc.cfg.Logger = Logger(logger)

// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
mc.sequence = 1
_, err := mc.readPacket()
if err != ErrPktSync {
t.Errorf("expected ErrPktSync, got %v", err)
data, err := mc.readPacket()
if err != nil {
t.Errorf("expected nil, got %v", err)
}
if len(data) != 1 || data[0] != 0xff {
t.Errorf("expected [0xff], got % x", data)
}
logMsg := logger.String()
if logMsg != ErrPktSync.Error()+"\n" {
t.Errorf("expected ErrPktSync.Error(), got %q", logMsg)
}

// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
logger.Reset()

// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
data, err = mc.readPacket()
if err != nil {
t.Errorf("expected nil, got %v", err)
}
if len(data) != 1 || data[0] != 0xff {
t.Errorf("expected [0xff], got % x", data)
}
logMsg = logger.String()
if logMsg != ErrPktSyncMul.Error()+"\n" {
t.Errorf("expected ErrPktSync.Error(), got %q", logMsg)
}
}

Expand Down