Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return ErrorResponse on query block #186

Merged
merged 4 commits into from
May 30, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 144 additions & 62 deletions decryptor/postgresql/pg_decryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,49 @@ import (
log "github.com/sirupsen/logrus"
)

// ReadForQuery - 'Z' ReadyForQuery, 0 0 0 5 length, 'I' idle status
// https://www.postgresql.org/docs/9.3/static/protocol-message-formats.html
var ReadyForQueryPacket = []byte{'Z', 0, 0, 0, 5, 'I'}

func NewPgError(message string) ([]byte, error) {
// 5 = E marker + 4 bytes for message length
// 7 is severity error with null terminator
// +1 for null terminator of message and packet
output := make([]byte, 5+7+7+len(message)+2)
// error message
output[0] = 'E'
// leave untouched place for length of data
output = output[:5]
// error severity
output = append(output, []byte{'S', 'E', 'R', 'R', 'O', 'R', 0}...)
// 42000 - syntax_error_or_access_rule_violation
// https://www.postgresql.org/docs/9.3/static/errcodes-appendix.html
output = append(output, []byte("C42000")...)
output = append(output, 0)
// human readable message
output = append(output, append([]byte{'M'}, []byte(message)...)...)
output = append(output, 0, 0)
// place length of data
// -1 byte to exclude type of message
// 1:5 4 bytes for packet length without first byte of message type
binary.BigEndian.PutUint32(output[1:5], uint32(len(output)-1))
return output, nil
}

type DataRow struct {
buf [1]byte
output []byte
messageType [1]byte
descriptionLengthBuf []byte
columnSizePointer []byte
columnDataBuf *bytes.Buffer
writeIndex int
columnCount int
dataLength int
errCh chan<- error
reader *acra_io.ExtendedBufferedReader
writer *bufio.Writer
descriptionBuf *bytes.Buffer

output []byte
columnSizePointer []byte
columnDataBuf *bytes.Buffer
writeIndex int
columnCount int
dataLength int
errCh chan<- error
reader *acra_io.ExtendedBufferedReader
writer *bufio.Writer
}

const (
Expand Down Expand Up @@ -94,25 +125,24 @@ func (row *DataRow) skipData(reader io.Reader, writer io.Writer, errCh chan<- er
return true
}

func (row *DataRow) readByte(reader io.Reader, writer io.Writer, errCh chan<- error) bool {
n, err := reader.Read(row.buf[:])
func (row *DataRow) readMessageType(reader io.Reader, writer io.Writer, errCh chan<- error) bool {
n, err := reader.Read(row.messageType[:])
if !base.CheckReadWrite(n, 1, err, errCh) {
return false
}
log.Printf("byte=%v, '%v'", row.buf[0], string(row.buf[0]))
n, err = writer.Write(row.buf[:])
n, err = writer.Write(row.messageType[:])
if !base.CheckReadWrite(n, 1, err, errCh) {
return false
}
return true
}

func (row *DataRow) IsDataRow() bool {
return row.buf[0] == DATA_ROW_MESSAGE_TYPE
return row.messageType[0] == DATA_ROW_MESSAGE_TYPE
}

func (row *DataRow) IsSimpleQuery() bool {
return row.buf[0] == QUERY_MESSAGE_TYPE
return row.messageType[0] == QUERY_MESSAGE_TYPE
}

func (row *DataRow) UpdateColumnAndDataSize(oldColumnLength, newColumnLength int) bool {
Expand Down Expand Up @@ -146,7 +176,6 @@ func (row *DataRow) ReadDataLength() bool {
}
row.writeIndex += n
row.dataLength = int(binary.BigEndian.Uint32(row.output[:DATA_ROW_LENGTH_BUF_SIZE])) - len(row.descriptionLengthBuf)
log.Printf("data length buf=%v, %v", row.output[:DATA_ROW_LENGTH_BUF_SIZE], row.dataLength)
return true
}

Expand Down Expand Up @@ -195,6 +224,43 @@ func (row *DataRow) ReadSimpleQuery(errCh chan<- error) (string, bool) {
return string(query), success
}

var ErrShortRead = errors.New("read less bytes than expected")

func (row *DataRow) ReadRow(reader io.Reader, errCh chan<- error) (*DataRow, error) {
n, err := reader.Read(row.messageType[:])
if err != nil {
return nil, err
}
if n != 1 {
return nil, ErrShortRead
}
n, err = reader.Read(row.descriptionLengthBuf)
if err != nil {
return nil, err
}
if n != len(row.descriptionLengthBuf) {
return nil, ErrShortRead
}
row.dataLength = int(binary.BigEndian.Uint32(row.descriptionLengthBuf)) - len(row.descriptionLengthBuf)
row.descriptionBuf.Reset()
nn, err := io.CopyN(row.descriptionBuf, reader, int64(row.dataLength))
if err != nil {
return nil, err
}
if nn != int64(row.dataLength) {
return nil, ErrShortRead
}
return row, nil
}

func (row *DataRow) Marshal() ([]byte, error) {
output := make([]byte, 0, 5+row.dataLength)
output = append(output, row.messageType[0])
output = append(output, row.descriptionLengthBuf...)
output = append(output, row.descriptionBuf.Bytes()...)
return output, nil
}

type PgProxy struct {
clientConnection net.Conn
dbConnection net.Conn
Expand All @@ -206,45 +272,32 @@ func NewPgProxy(clientConnection, dbConnection net.Conn, errCh chan<- error) (*P
return &PgProxy{clientConnection: clientConnection, dbConnection: dbConnection, errCh: errCh, TlsCh: make(chan bool)}, nil
}

func NewPgError(message string) ([]byte, error) {
// 5 = E marker + 4 bytes for message length
// 7 is severity error with null terminator
// +1 for null terminator of message
output := make([]byte, 5+7+7+len(message)+1)
// error message
output[0] = 'E'
// leave untouched place for length of data
output = output[:5]
// error severity
output = append(output, []byte{'S', 'E', 'R', 'R', 'O', 'R', 0}...)
// 42501 insufficient_privilege
// https://www.postgresql.org/docs/9.3/static/errcodes-appendix.html
output = append(output, []byte("C42501")...)
output = append(output, 0)
// human readable message
output = append(output, append([]byte{'M'}, []byte(message)...)...)
output = append(output, 0)
// place length of data
binary.BigEndian.PutUint32(output[1:5], uint32(len(output)-1))
return output, nil
func NewClientSideDataRow(reader *acra_io.ExtendedBufferedReader, writer *bufio.Writer) (*DataRow, error) {
return &DataRow{
writeIndex: 0,
output: nil,
columnDataBuf: nil,
descriptionBuf: bytes.NewBuffer(make([]byte, OUTPUT_DEFAULT_SIZE)),
descriptionLengthBuf: make([]byte, 4),
reader: reader,
writer: writer,
}, nil
}

func (proxy *PgProxy) PgProxyClientRequests(acraCensor acracensor.AcraCensorInterface, dbConnection, clientConnection net.Conn, errCh chan<- error) {
log.Debugln("pg client proxy")
writer := bufio.NewWriter(dbConnection)

reader := acra_io.NewExtendedBufferedReader(bufio.NewReader(clientConnection))
row := DataRow{
writeIndex: 0,
output: make([]byte, OUTPUT_DEFAULT_SIZE),
columnDataBuf: bytes.NewBuffer(make([]byte, COLUMN_DATA_DEFAULT_SIZE)),
descriptionLengthBuf: make([]byte, 4),
reader: reader,
writer: writer,
row, err := NewClientSideDataRow(reader, writer)
if err != nil {
log.WithError(err).Errorln("can't initialize DataRow object")
errCh <- err
return
}
firstByte := true
for {
row.writeIndex = 0
row.descriptionBuf.Reset()
if firstByte {
log.Debugln("first packet")
// first packet hasn't type of message and start with message length and data
Expand All @@ -259,28 +312,36 @@ func (proxy *PgProxy) PgProxyClientRequests(acraCensor acracensor.AcraCensorInte
}
continue
}
if !row.readByte(reader, writer, errCh) {
log.Debugln("can't read byte")
row, err := row.ReadRow(reader, errCh)
if err != nil {
log.WithError(err).Errorln("can't read row")
errCh <- err
return
}
if !row.IsSimpleQuery() {
log.Debugln("not query")
if !row.skipData(reader, writer, errCh) {
output, err := row.Marshal()
if err != nil {
log.WithError(err).Errorln("Can't dump row")
errCh <- err
return

}
n, err := writer.Write(output)
if !base.CheckReadWrite(n, len(output), err, errCh) {
return
}
if err := writer.Flush(); err != nil {
log.WithError(err).Errorln("can't flush writer")
errCh <- err
return
}
writer.Flush()
continue
}
log.Debugln("query packet")
query, success := row.ReadSimpleQuery(errCh)
if !success {
row.Flush()
return
}
query := string(row.descriptionBuf.Bytes()[:row.dataLength-1])
log.WithField("query", query).Debugln("new query")
if censorErr := acraCensor.HandleQuery(query); censorErr != nil {
log.WithError(censorErr).Errorln("AcraCensor blocked query")
errCh <- censorErr
errorMessage, err := NewPgError("AcraCensor blocked this query")
if err != nil {
log.WithError(err).Errorln("can't create postgresql error message")
Expand All @@ -290,19 +351,40 @@ func (proxy *PgProxy) PgProxyClientRequests(acraCensor acracensor.AcraCensorInte
if !base.CheckReadWrite(n, len(errorMessage), err, row.errCh) {
return
}
n, err = clientConnection.Write(ReadyForQueryPacket)
if !base.CheckReadWrite(n, len(ReadyForQueryPacket), err, row.errCh) {
return
}
continue
}

output, err := row.Marshal()
if err != nil {
log.WithError(err).Errorln("Can't dump row")
errCh <- err
return

}
if !row.Flush() {
n, err := writer.Write(output)
if !base.CheckReadWrite(n, len(output), err, errCh) {
return
}
if err := writer.Flush(); err != nil {
log.WithError(err).Errorln("can't flush writer to db")
log.WithError(err).Errorln("can't flush writer")
errCh <- err
return
}
}
}

func (row *DataRow) IsSSLRequest() bool {
return row.messageType[0] == 'S'
}

func (row *DataRow) IsSSLRequestDeny() bool {
return row.messageType[0] == 'N'
}

func (proxy *PgProxy) PgDecryptStream(censor acracensor.AcraCensorInterface, decryptor base.Decryptor, tlsConfig *tls.Config, dbConnection net.Conn, clientConnection net.Conn, errCh chan<- error) {
log.Debugln("pg db proxy")
writer := bufio.NewWriter(clientConnection)
Expand All @@ -318,19 +400,19 @@ func (proxy *PgProxy) PgDecryptStream(censor acracensor.AcraCensorInterface, dec
}
firstByte := true
for {
if !row.readByte(reader, writer, errCh) {
if !row.readMessageType(reader, writer, errCh) {
return
}

if firstByte {
// https://www.postgresql.org/docs/9.1/static/protocol-flow.html#AEN92112
// we should know that we shouldn't read anymore bytes
firstByte = false
if row.buf[0] == 'N' {
if row.IsSSLRequestDeny() {
log.Debugln("deny ssl request")
writer.Flush()
continue
} else if row.buf[0] == 'S' {
} else if row.IsSSLRequest() {
if tlsConfig == nil {
log.Errorln("To support TLS connections you must pass TLS key and certificate for AcraServer that will be used" +
"for connections AcraServer->Database and CA certificate which will be used to verify certificate " +
Expand Down