package mssql

import (
	"encoding/binary"
	"errors"
	"io"
)

// packetType identifies the kind of TDS message a packet belongs to. It is
// stored in byte 0 of every packet header.
type packetType uint8

// header mirrors the 8-byte TDS packet header as laid out on the wire.
// Multi-byte fields are big-endian.
type header struct {
	PacketType packetType
	Status     uint8
	Size       uint16
	Spid       uint16
	PacketNo   uint8
	Pad        uint8
}

// tdsBuffer reads and writes TDS packets of data to the transport.
// The write and read buffers are separate to make sending attn signals
// possible without locks. Currently attn signals are only sent during
// reads, not writes.
type tdsBuffer struct {
	transport io.ReadWriteCloser

	// packetSize is the largest packet, header included, that Write and
	// WriteByte emit and that readNextPacket accepts.
	packetSize int

	// Write fields.
	wbuf        []byte     // packet being assembled; [0:8] is the header
	wpos        int        // next free position in wbuf
	wPacketSeq  byte       // sequence number stamped into header byte 6
	wPacketType packetType // type stamped into header byte 0

	// Read fields.
	rbuf        []byte     // last packet read; its payload starts at headerSize
	rpos        int        // next unread position in rbuf
	rsize       int        // end of the current packet's data in rbuf
	final       bool       // the current packet ended its message
	rPacketType packetType // type of the current packet

	// afterFirst is assigned to right after tdsBuffer is created and
	// before the first use. It is executed after the first packet is
	// written and then removed.
	afterFirst func()
}

// newTdsBuffer returns a tdsBuffer over transport whose packets are at most
// bufsize bytes long. Both buffers are allocated at 64 KiB, which is enough
// for any length a 16-bit size field can describe. ResizeBuffer therefore
// only updates packetSize and never reallocates.
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
	// rpos == rsize means "no unread data buffered". The first read then
	// pulls a packet from the transport instead of returning garbage from an
	// uninitialized window.
	return &tdsBuffer{
		packetSize: int(bufsize),
		wbuf:       make([]byte, 1<<16),
		rbuf:       make([]byte, 1<<16),
		rpos:       8,
		rsize:      8,
		transport:  transport,
	}
}

// ResizeBuffer changes the maximum packet size used for subsequent reads and
// writes.
func (b *tdsBuffer) ResizeBuffer(packetSize int) {
	b.packetSize = packetSize
}

// PackageSize returns the current maximum packet size.
func (b *tdsBuffer) PackageSize() int {
	return b.packetSize
}

// flush finalizes the header of the packet in wbuf and writes the packet to
// the transport. The header fields set here are the type, the total length
// and the sequence number. The status byte (wbuf[1]) is left as
// BeginPacket/FinishPacket set it. On success wbuf is reset to hold the next
// packet of the same message, and the sequence number is advanced.
func (b *tdsBuffer) flush() error {
	b.wbuf[0] = byte(b.wPacketType)
	binary.BigEndian.PutUint16(b.wbuf[2:], uint16(b.wpos))
	b.wbuf[6] = b.wPacketSeq

	if _, err := b.transport.Write(b.wbuf[:b.wpos]); err != nil {
		return err
	}

	// One-shot hook: runs only after the very first packet has been written.
	if b.afterFirst != nil {
		b.afterFirst()
		b.afterFirst = nil
	}

	b.wpos = 8
	b.wPacketSeq++
	return nil
}

// Write appends p to the outgoing message. Each time the current packet
// reaches packetSize bytes, that packet is flushed to the transport. The
// returned count is the number of bytes of p copied into packets. If a flush
// fails, that count includes the bytes of the packet whose flush failed.
func (b *tdsBuffer) Write(p []byte) (total int, err error) {
	for {
		copied := copy(b.wbuf[b.wpos:b.packetSize], p)
		b.wpos += copied
		total += copied
		if copied == len(p) {
			return total, nil
		}
		if err = b.flush(); err != nil {
			return total, err
		}
		p = p[copied:]
	}
}

// WriteByte appends a single byte to the outgoing message. If the current
// packet is already full, it is flushed first.
func (b *tdsBuffer) WriteByte(c byte) error {
	if b.wpos == len(b.wbuf) || b.wpos == b.packetSize {
		if err := b.flush(); err != nil {
			return err
		}
	}
	b.wbuf[b.wpos] = c
	b.wpos++
	return nil
}

// BeginPacket starts assembling a new outgoing message of the given type. It
// resets the write position and restarts the sequence number at 1. When
// resetSession is true and the type allows it, status bit 0x08 is set.
func (b *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) {
	status := byte(0)
	if resetSession {
		switch packetType {
		// Reset session can only be set on the following packet types.
		case packSQLBatch, packRPCRequest, packTransMgrReq:
			status = 0x8
		}
	}
	b.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket.
	b.wpos = 8
	b.wPacketSeq = 1
	b.wPacketType = packetType
}

// FinishPacket marks the packet being assembled as the last one of the
// message (status bit 0x01) and writes it to the transport.
func (b *tdsBuffer) FinishPacket() error {
	b.wbuf[1] |= 1
	return b.flush()
}

// headerSize is the length in bytes of a TDS packet header (8).
var headerSize = binary.Size(header{})

// readNextPacket reads one packet from the transport into rbuf. It then makes
// the packet's payload, rbuf[headerSize:size], the current read window.
//
// The header is read directly into rbuf and decoded with binary.BigEndian
// rather than binary.Read. binary.Read decodes via reflection and would run
// on every incoming packet. The returned errors are the same, because both
// paths read through io.ReadFull.
func (b *tdsBuffer) readNextPacket() error {
	hdr := b.rbuf[:headerSize]
	if _, err := io.ReadFull(b.transport, hdr); err != nil {
		return err
	}
	pktType := packetType(hdr[0])
	status := hdr[1]
	size := int(binary.BigEndian.Uint16(hdr[2:4]))

	if size > b.packetSize {
		return errors.New("Invalid packet size, it is longer than buffer size")
	}
	if size < headerSize {
		return errors.New("Invalid packet size, it is shorter than header size")
	}
	if _, err := io.ReadFull(b.transport, b.rbuf[headerSize:size]); err != nil {
		return err
	}

	b.rpos = headerSize
	b.rsize = size
	// Any non-zero status ends the message. FinishPacket uses bit 0x01 for
	// this on the write side.
	b.final = status != 0
	b.rPacketType = pktType
	return nil
}

// BeginRead reads the first packet of an incoming message and returns its
// packet type.
func (b *tdsBuffer) BeginRead() (packetType, error) {
	if err := b.readNextPacket(); err != nil {
		return 0, err
	}
	return b.rPacketType, nil
}

// fillReadBuffer makes sure at least one unread byte of the current message
// is buffered. It keeps reading packets until one with a non-empty payload
// arrives. It returns io.EOF once the final packet has been fully consumed.
// Looping, rather than reading a single packet, keeps a zero-payload packet
// from exposing stale bytes that an earlier packet left in rbuf.
func (b *tdsBuffer) fillReadBuffer() error {
	for b.rpos == b.rsize {
		if b.final {
			return io.EOF
		}
		if err := b.readNextPacket(); err != nil {
			return err
		}
	}
	return nil
}

// ReadByte returns the next byte of the current message, reading further
// packets from the transport as needed. It returns io.EOF after the final
// packet has been consumed.
func (b *tdsBuffer) ReadByte() (byte, error) {
	if err := b.fillReadBuffer(); err != nil {
		return 0, err
	}
	c := b.rbuf[b.rpos]
	b.rpos++
	return c, nil
}

// byte is ReadByte with errors passed to badStreamPanic.
func (b *tdsBuffer) byte() byte {
	c, err := b.ReadByte()
	if err != nil {
		badStreamPanic(err)
	}
	return c
}

// ReadFull fills buf completely from the current message. Errors, including
// a message that ends early, are passed to badStreamPanic.
func (b *tdsBuffer) ReadFull(buf []byte) {
	if _, err := io.ReadFull(b, buf); err != nil {
		badStreamPanic(err)
	}
}

// uint64 reads a little-endian uint64 from the message.
func (b *tdsBuffer) uint64() uint64 {
	var buf [8]byte
	b.ReadFull(buf[:])
	return binary.LittleEndian.Uint64(buf[:])
}

// int32 reads a little-endian 32-bit value and reinterprets it as signed.
func (b *tdsBuffer) int32() int32 {
	return int32(b.uint32())
}

// uint32 reads a little-endian uint32 from the message.
func (b *tdsBuffer) uint32() uint32 {
	var buf [4]byte
	b.ReadFull(buf[:])
	return binary.LittleEndian.Uint32(buf[:])
}

// uint16 reads a little-endian uint16 from the message.
func (b *tdsBuffer) uint16() uint16 {
	var buf [2]byte
	b.ReadFull(buf[:])
	return binary.LittleEndian.Uint16(buf[:])
}

// BVarChar reads a string from the message using readBVarChar.
func (b *tdsBuffer) BVarChar() string {
	return readBVarCharOrPanic(b)
}

// readBVarCharOrPanic calls readBVarChar on r and passes any error to
// badStreamPanic.
func readBVarCharOrPanic(r io.Reader) string {
	s, err := readBVarChar(r)
	if err != nil {
		badStreamPanic(err)
	}
	return s
}

// readUsVarCharOrPanic calls readUsVarChar on r and passes any error to
// badStreamPanic.
func readUsVarCharOrPanic(r io.Reader) string {
	s, err := readUsVarChar(r)
	if err != nil {
		badStreamPanic(err)
	}
	return s
}

// UsVarChar reads a string from the message using readUsVarChar.
func (b *tdsBuffer) UsVarChar() string {
	return readUsVarCharOrPanic(b)
}

// Read copies unread bytes of the current message into buf. When the current
// packet is exhausted, it first reads the next one. It never returns more
// than the remainder of a single packet. It returns io.EOF once the final
// packet has been consumed, and never returns (0, nil) because of an empty
// packet.
func (b *tdsBuffer) Read(buf []byte) (int, error) {
	if err := b.fillReadBuffer(); err != nil {
		return 0, err
	}
	n := copy(buf, b.rbuf[b.rpos:b.rsize])
	b.rpos += n
	return n, nil
}