package mssql import ( "context" "encoding/binary" "errors" "fmt" "io" "net" "strconv" "strings" ) //go:generate stringer -type token type token byte // token ids const ( tokenReturnStatus token = 121 // 0x79 tokenColMetadata token = 129 // 0x81 tokenOrder token = 169 // 0xA9 tokenError token = 170 // 0xAA tokenInfo token = 171 // 0xAB tokenReturnValue token = 0xAC tokenLoginAck token = 173 // 0xad tokenFeatureExtAck token = 174 // 0xae tokenRow token = 209 // 0xd1 tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 ) // done flags // https://msdn.microsoft.com/en-us/library/dd340421.aspx const ( doneFinal = 0 doneMore = 1 doneError = 2 doneInxact = 4 doneCount = 0x10 doneAttn = 0x20 doneSrvError = 0x100 ) // ENVCHANGE types // http://msdn.microsoft.com/en-us/library/dd303449.aspx const ( envTypDatabase = 1 envTypLanguage = 2 envTypCharset = 3 envTypPacketSize = 4 envSortId = 5 envSortFlags = 6 envSqlCollation = 7 envTypBeginTran = 8 envTypCommitTran = 9 envTypRollbackTran = 10 envEnlistDTC = 11 envDefectTran = 12 envDatabaseMirrorPartner = 13 envPromoteTran = 15 envTranMgrAddr = 16 envTranEnded = 17 envResetConnAck = 18 envStartedInstanceName = 19 envRouting = 20 ) // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( colFlagNullable = 1 // TODO implement more flags ) // interface for all tokens type tokenStruct interface{} type orderStruct struct { ColIds []uint16 } type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 errors []Error } func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } func (d doneStruct) getError() Error { if len(d.errors) > 0 { return d.errors[len(d.errors)-1] } else { return Error{Message: "Request failed but didn't provide reason"} } } type doneInProcStruct doneStruct var doneFlags2str = map[uint16]string{ doneFinal: "final", doneMore: "more", doneError: "error", doneInxact: "inxact", doneCount: "count", doneAttn: "attn", doneSrvError: "srverror", } func doneFlags2Str(flags uint16) string { strs := make([]string, 0, len(doneFlags2str)) for flag, tag := range doneFlags2str { if flags&flag != 0 { strs = append(strs, tag) } } return strings.Join(strs, "|") } // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx func processEnvChg(sess *tdsSession) { size := sess.buf.uint16() r := &io.LimitedReader{R: sess.buf, N: int64(size)} for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { return } if err != nil { badStreamPanic(err) } switch envtype { case envTypDatabase: sess.database, err = readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } case envTypLanguage: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypCharset: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) if err != nil { badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSortFlags: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSqlCollation: // currently ignored var collationSize uint8 err = binary.Read(r, binary.LittleEndian, &collationSize) if err != nil { badStreamPanic(err) } // SQL Collation data should contain 5 bytes in length if collationSize != 5 { badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) } // 4 bytes, contains: LCID ColFlags Version var info uint32 err = binary.Read(r, binary.LittleEndian, &info) if err != nil { badStreamPanic(err) } // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) if len(tranid) != 8 { badStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) } sess.tranid = binary.LittleEndian.Uint64(tranid) if err != nil { badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid) } _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } case envTypCommitTran, envTypRollbackTran: _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { if envtype == envTypCommitTran { sess.log.Printf("COMMIT TRANSACTION %x\n", sess.tranid) } else { sess.log.Printf("ROLLBACK TRANSACTION %x\n", sess.tranid) } } sess.tranid = 0 case envEnlistDTC: // currently ignored // new value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envDefectTran: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } case envPromoteTran: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // dtc token // spec says it should be L_VARBYTE, so this code might be wrong if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTranMgrAddr: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // XACT_MANAGER_ADDRESS = B_VARBYTE if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTranEnded: // currently ignored // old value, B_VARBYTE if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envResetConnAck: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envStartedInstanceName: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // instance name if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envRouting: // RoutingData message is: // ValueLength USHORT // Protocol (TCP = 0) BYTE // ProtocolProperty (new port) USHORT // AlternateServer US_VARCHAR _, err := readUshort(r) if err != nil { badStreamPanic(err) } protocol, err := readByte(r) if err != nil || protocol != 0 { badStreamPanic(err) } newPort, err := readUshort(r) if err != nil { badStreamPanic(err) } newServer, err := readUsVarChar(r) if err != nil { badStreamPanic(err) } // consume the OLDVALUE = %x00 %x00 _, err = readUshort(r) if err != nil { badStreamPanic(err) } sess.routedServer = newServer sess.routedPort = newPort default: // ignore rest of records because we don't know how to skip those sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype) break } } } // http://msdn.microsoft.com/en-us/library/dd358180.aspx func parseReturnStatus(r *tdsBuffer) ReturnStatus { return ReturnStatus(r.int32()) } func parseOrder(r *tdsBuffer) (res orderStruct) { len := int(r.uint16()) res.ColIds = make([]uint16, len/2) for i := 0; i < len/2; i++ { res.ColIds[i] = r.uint16() } return res } // https://msdn.microsoft.com/en-us/library/dd340421.aspx func parseDone(r *tdsBuffer) (res doneStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() return res } // https://msdn.microsoft.com/en-us/library/dd340553.aspx func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() return res } type sspiMsg []byte func parseSSPIMsg(r *tdsBuffer) sspiMsg { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) return sspiMsg(buf) } type loginAckStruct struct { Interface uint8 TDSVersion uint32 ProgName string ProgVer uint32 } func parseLoginAck(r *tdsBuffer) loginAckStruct { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) var res loginAckStruct res.Interface = buf[0] res.TDSVersion = binary.BigEndian.Uint32(buf[1:]) prognamelen := buf[1+4] var err error if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil { badStreamPanic(err) } res.ProgVer = binary.BigEndian.Uint32(buf[size-4:]) return res } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a func parseFeatureExtAck(r *tdsBuffer) { // at most 1 featureAck per feature in featureExt // go-mssqldb will add at most 1 feature, the spec defines 7 different features for i := 0; i < 8; i++ { featureID := r.byte() // FeatureID if featureID == 0xff { return } size := r.uint32() // FeatureAckDataLen d := make([]byte, size) r.ReadFull(d) } panic("parsed more than 7 featureAck's, protocol implementation error?") } // http://msdn.microsoft.com/en-us/library/dd357363.aspx func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) for i := range columns { column := &columns[i] column.UserType = r.uint32() column.Flags = r.uint16() // parsing TYPE_INFO structure column.ti = readTypeInfo(r) column.ColName = r.BVarChar() } return columns } // http://msdn.microsoft.com/en-us/library/dd357254.aspx func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { for i, column := range columns { row[i] = column.ti.Reader(&column.ti, r) } } // http://msdn.microsoft.com/en-us/library/dd304783.aspx func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) for i, col := range columns { if pres[i/8]&(1<<(uint(i)%8)) != 0 { row[i] = nil continue } row[i] = col.ti.Reader(&col.ti, r) } } // http://msdn.microsoft.com/en-us/library/dd304156.aspx func parseError72(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() res.State = r.byte() res.Class = r.byte() res.Message = r.UsVarChar() res.ServerName = r.BVarChar() res.ProcName = r.BVarChar() res.LineNo = r.int32() return } // http://msdn.microsoft.com/en-us/library/dd304156.aspx func parseInfo(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() res.State = r.byte() res.Class = r.byte() res.Message = r.UsVarChar() res.ServerName = r.BVarChar() res.ProcName = r.BVarChar() res.LineNo = r.int32() return } // https://msdn.microsoft.com/en-us/library/dd303881.aspx func parseReturnValue(r *tdsBuffer) (nv namedValue) { /* ParamOrdinal ParamName Status UserType Flags TypeInfo CryptoMetadata Value */ r.uint16() nv.Name = r.BVarChar() r.byte() r.uint32() // UserType (uint16 prior to 7.2) r.uint16() ti := readTypeInfo(r) nv.Value = ti.Reader(&ti, r) return } func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { sess.log.Printf("ERROR: Intercepted panic %v", err) } ch <- err } close(ch) }() packet_type, err := sess.buf.BeginRead() if err != nil { if sess.logFlags&logErrors != 0 { sess.log.Printf("ERROR: BeginRead failed %v", err) } ch <- err return } if packet_type != packReply { badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct errs := make([]Error, 0, 5) for { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { sess.log.Printf("got token %v", token) } switch token { case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus case tokenLoginAck: loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: parseFeatureExtAck(sess.buf) case tokenOrder: order := parseOrder(sess.buf) ch <- order case tokenDoneInProc: done := parseDoneInProc(sess.buf) if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { sess.log.Printf("(%d row(s) affected)\n", done.RowCount) } ch <- done case tokenDone, tokenDoneProc: done := parseDone(sess.buf) done.errors = errs if sess.logFlags&logDebug != 0 { sess.log.Printf("got DONE or DONEPROC status=%d", done.Status) } if done.Status&doneSrvError != 0 { ch <- errors.New("SQL Server had internal error") return } if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { sess.log.Printf("(%d row(s) affected)\n", done.RowCount) } ch <- done if done.Status&doneMore == 0 { return } case tokenColMetadata: columns = parseColMetadata72(sess.buf) ch <- columns case tokenRow: row := make([]interface{}, len(columns)) parseRow(sess.buf, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) parseNbcRow(sess.buf, columns, row) ch <- row case tokenEnvChange: processEnvChg(sess) case tokenError: err := parseError72(sess.buf) if sess.logFlags&logDebug != 0 { sess.log.Printf("got ERROR %d %s", err.Number, err.Message) } errs = append(errs, err) if sess.logFlags&logErrors != 0 { sess.log.Println(err.Message) } case tokenInfo: info := parseInfo(sess.buf) if sess.logFlags&logDebug != 0 { sess.log.Printf("got INFO %d %s", info.Number, info.Message) } if sess.logFlags&logMessages != 0 { sess.log.Println(info.Message) } case tokenReturnValue: nv := parseReturnValue(sess.buf) if len(nv.Name) > 0 { name := nv.Name[1:] // Remove the leading "@". if ov, has := outs[name]; has { err = scanIntoOut(name, nv.Value, ov) if err != nil { fmt.Println("scan error", err) ch <- err } } } default: badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) } } } type parseRespIter byte const ( parseRespIterContinue parseRespIter = iota // Continue parsing current token. parseRespIterNext // Fetch the next token. parseRespIterDone // Done with parsing the response. ) type parseRespState byte const ( parseRespStateNormal parseRespState = iota // Normal response state. parseRespStateCancel // Query is canceled, wait for server to confirm. parseRespStateClosing // Waiting for tokens to come through. ) type parseResp struct { sess *tdsSession ctxDone <-chan struct{} state parseRespState cancelError error } func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter { if err := sendAttention(ts.sess.buf); err != nil { ts.dlogf("failed to send attention signal %v", err) ch <- err return parseRespIterDone } ts.state = parseRespStateCancel return parseRespIterContinue } func (ts *parseResp) dlog(msg string) { if ts.sess.logFlags&logDebug != 0 { ts.sess.log.Println(msg) } } func (ts *parseResp) dlogf(f string, v ...interface{}) { if ts.sess.logFlags&logDebug != 0 { ts.sess.log.Printf(f, v...) } } func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter { switch ts.state { default: panic("unknown state") case parseRespStateNormal: select { case tok, ok := <-tokChan: if !ok { ts.dlog("response finished") return parseRespIterDone } if err, ok := tok.(net.Error); ok && err.Timeout() { ts.cancelError = err ts.dlog("got timeout error, sending attention signal to server") return ts.sendAttention(ch) } // Pass the token along. ch <- tok return parseRespIterContinue case <-ts.ctxDone: ts.ctxDone = nil ts.dlog("got cancel message, sending attention signal to server") return ts.sendAttention(ch) } case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth select { case tok, ok := <-tokChan: if !ok { ts.dlog("response finished but waiting for attention ack") return parseRespIterNext } switch tok := tok.(type) { default: // Ignore all other tokens while waiting. // The TDS spec says other tokens may arrive after an attention // signal is sent. Ignore these tokens and continue looking for // a DONE with attention confirm mark. case doneStruct: if tok.Status&doneAttn != 0 { ts.dlog("got cancellation confirmation from server") if ts.cancelError != nil { ch <- ts.cancelError ts.cancelError = nil } else { ch <- ctx.Err() } return parseRespIterDone } // If an error happens during cancel, pass it along and just stop. // We are uncertain to receive more tokens. case error: ch <- tok ts.state = parseRespStateClosing } return parseRespIterContinue case <-ts.ctxDone: ts.ctxDone = nil ts.state = parseRespStateClosing return parseRespIterContinue } case parseRespStateClosing: // Wait for current token chan to close. if _, ok := <-tokChan; !ok { ts.dlog("response finished") return parseRespIterDone } return parseRespIterContinue } } func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { ts := &parseResp{ sess: sess, ctxDone: ctx.Done(), } defer func() { // Ensure any remaining error is piped through // or the query may look like it executed when it actually failed. if ts.cancelError != nil { ch <- ts.cancelError ts.cancelError = nil } close(ch) }() // Loop over multiple responses. for { ts.dlog("initiating response reading") tokChan := make(chan tokenStruct) go processSingleResponse(sess, tokChan, outs) // Loop over multiple tokens in response. tokensLoop: for { switch ts.iter(ctx, ch, tokChan) { case parseRespIterContinue: // Nothing, continue to next token. case parseRespIterNext: break tokensLoop case parseRespIterDone: return } } } }