First commit
This commit is contained in:
9
LICENSE
Normal file
9
LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Fedorenko Dmitrij
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
25
Makefile
Normal file
25
Makefile
Normal file
@@ -0,0 +1,25 @@
|
||||
GOARCH := $(shell go env GOARCH)
|
||||
GOOS := $(shell go env GOOS)
|
||||
|
||||
GOLANGCI_LINT_VERSION := 1.16.0
|
||||
GOLANGCI_LINT_ARCHIVE_NAME := golangci-lint-${GOLANGCI_LINT_VERSION}-${GOOS}-${GOARCH}
|
||||
GOLANGCI_LINT_URL := https://github.com/golangci/golangci-lint/releases/download/v${GOLANGCI_LINT_VERSION}/${GOLANGCI_LINT_ARCHIVE_NAME}.tar.gz
|
||||
|
||||
export PATH := $(PWD)/bin:$(PATH)
|
||||
export GO111MODULE=on
|
||||
|
||||
default: lint test
|
||||
|
||||
bin/${GOLANGCI_LINT_ARCHIVE_NAME}/:
|
||||
mkdir -p bin
|
||||
curl -L ${GOLANGCI_LINT_URL} | tar --directory bin/ --gzip --extract --verbose
|
||||
|
||||
bin/golangci-lint: bin/${GOLANGCI_LINT_ARCHIVE_NAME}/
|
||||
ln -f -s $(PWD)/bin/${GOLANGCI_LINT_ARCHIVE_NAME}/golangci-lint $@
|
||||
touch $@
|
||||
|
||||
lint: bin/golangci-lint
|
||||
golangci-lint run ./...
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
70
README.md
Normal file
70
README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# TODO
|
||||
REWRITE THIS
|
||||
only proxyprotocol.proxyprotocol.NewListener(rawListener) works, the sourcechecker and the logger are not there anymore or the choice to define the proxy protocol version
|
||||
|
||||
# go-proxyprotocol
|
||||
|
||||
[](https://godoc.org/github.com/c0va23/go-proxyprotocol)
|
||||
[](https://goreportcard.com/report/github.com/c0va23/go-proxyprotocol)
|
||||
[](https://travis-ci.org/c0va23/go-proxyprotocol)
|
||||
|
||||
Golang package `github.com/c0va23/go-proxyprotocol' provide receiver for
|
||||
[HA ProxyProtocol v1 and v2](http://www.haproxy.org/download/2.0/doc/proxy-protocol.txt).
|
||||
|
||||
This package provides a wrapper for the interface net.Listener, which extracts
|
||||
remote and local address of the connection from the headers in the format
|
||||
HA proxyprotocol.
|
||||
|
||||
## Usage example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/c0va23/go-proxyprotocol"
|
||||
)
|
||||
|
||||
func main() {
|
||||
rawList, _ := net.Listen("tcp", ":8080")
|
||||
|
||||
list := proxyprotocol.
|
||||
NewDefaultListener(rawList).
|
||||
WithLogger(proxyprotocol.LoggerFunc(log.Printf))
|
||||
|
||||
http.Serve(list, http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
log.Printf("Remote Addr: %s, URI: %s", req.RemoteAddr, req.RequestURI)
|
||||
fmt.Fprintf(res, "Hello, %s!\n", req.RemoteAddr)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
DefaultListener try parse proxyprotocol v1 and v2 header. If header signature
|
||||
not recognized, then used raw connection.
|
||||
|
||||
If you want to use only proxy protocol V1 or v2 headers, you can initialize the
|
||||
listener as follows:
|
||||
|
||||
```go
|
||||
list := proxyprotocol.NewListener(rawList, proxyprotocol.TextHeaderParserBuilder)
|
||||
```
|
||||
|
||||
## Implementation status
|
||||
|
||||
### Human-readable header format (Version 1)
|
||||
- [x] UNKNOWN
|
||||
- [x] IPv4
|
||||
- [x] IPv6
|
||||
|
||||
### Binary header format (version 2)
|
||||
- [x] Unspec
|
||||
- [x] TCP over IPv4
|
||||
- [x] TCP over IPv6
|
||||
- [ ] UDP over IPv4
|
||||
- [ ] UDP over IPv6
|
||||
- [ ] Unix Stream
|
||||
- [ ] Unix Datagram
|
||||
58
binary.go
Normal file
58
binary.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package proxyprotocol
|
||||
|
||||
// BinarySignature is magic prefix for proxyprotocol Binary
|
||||
var (
|
||||
BinarySignature = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
|
||||
BinarySignatureLen = len(BinarySignature)
|
||||
)
|
||||
|
||||
// BinaryVersion2 bits
|
||||
const (
|
||||
BinaryVersion2 byte = 0x20
|
||||
BinaryVersionMask byte = 0xF0
|
||||
)
|
||||
|
||||
// Commands
|
||||
const (
|
||||
BinaryCommandLocal byte = 0x00
|
||||
BinaryCommandProxy byte = 0x01
|
||||
BinaryCommandMask byte = 0x0F
|
||||
)
|
||||
|
||||
// Address families
|
||||
const (
|
||||
BinaryAFUnspec byte = 0x00
|
||||
BinaryAFInet byte = 0x10
|
||||
BinaryAFInet6 byte = 0x20
|
||||
BinaryAFUnix byte = 0x30
|
||||
BinaryAFMask byte = 0xF0
|
||||
)
|
||||
|
||||
// Transport protocols
|
||||
const (
|
||||
BinaryTPUnspec byte = 0x00
|
||||
BinaryTPStream byte = 0x01
|
||||
BinaryTPDgram byte = 0x02
|
||||
BinaryTPMask byte = 0x0F
|
||||
)
|
||||
|
||||
// Protocol variants
|
||||
var (
|
||||
BinaryProtocolUnspec = BinaryAFUnspec | BinaryTPUnspec
|
||||
BinaryProtocolTCPoverIPv4 = BinaryAFInet | BinaryTPStream
|
||||
BinaryProtocolUDPoverIPv4 = BinaryAFInet | BinaryTPDgram
|
||||
BinaryProtocolTCPoverIPv6 = BinaryAFInet6 | BinaryTPStream
|
||||
BinaryProtocolUDPoverIPv6 = BinaryAFInet6 | BinaryTPDgram
|
||||
BinaryProtocolUnixStream = BinaryAFUnix | BinaryTPStream
|
||||
BinaryProtocolUnixDatagram = BinaryAFUnix | BinaryTPDgram
|
||||
)
|
||||
|
||||
// Expected address length
|
||||
var (
|
||||
BinaryPortLen = 2
|
||||
)
|
||||
|
||||
// TLV types
|
||||
const (
|
||||
TLVTypeNoop byte = 0x04
|
||||
)
|
||||
132
binary_receive.go
Normal file
132
binary_receive.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrUnknownVersion = errors.New("unknown version")
|
||||
ErrUnknownCommand = errors.New("unknown command")
|
||||
ErrUnexpectedAddressLen = errors.New("unexpected address length")
|
||||
)
|
||||
|
||||
// Meta buffer byte position
|
||||
const (
|
||||
versionCommandPos = 0
|
||||
protocolPos = 1
|
||||
addressLenStartPos = 2
|
||||
addressLenEndPos = 4
|
||||
)
|
||||
|
||||
// BinaryHeaderParser parse proxyprotocol header from Reader
|
||||
type BinaryHeaderParser struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewBinaryHeaderParser construct BinaryHeaderParser
|
||||
func NewBinaryHeaderParser(logger Logger) BinaryHeaderParser {
|
||||
return BinaryHeaderParser{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Parse buffer
|
||||
func (parser BinaryHeaderParser) Parse(buf *bufio.Reader) (*Header, error) {
|
||||
magicBuf, err := buf.Peek(BinarySignatureLen)
|
||||
if err != nil {
|
||||
parser.logger.Printf("Read magic prefix error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(magicBuf, BinarySignature) {
|
||||
return nil, ErrInvalidSignature
|
||||
}
|
||||
|
||||
_, err = buf.Discard(BinarySignatureLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaBuf := make([]byte, addressLenEndPos)
|
||||
if _, err = buf.Read(metaBuf); err != nil {
|
||||
parser.logger.Printf("Read meta error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
versionCommandByte := metaBuf[versionCommandPos]
|
||||
|
||||
if versionCommandByte&BinaryVersionMask != BinaryVersion2 {
|
||||
return nil, ErrUnknownVersion
|
||||
}
|
||||
|
||||
addressSizeBuf := metaBuf[addressLenStartPos:addressLenEndPos]
|
||||
addressesLen := int(binary.BigEndian.Uint16(addressSizeBuf))
|
||||
parser.logger.Printf("Addresses len: %d", addressesLen)
|
||||
|
||||
addressesBuf := make([]byte, addressesLen)
|
||||
addressReaded, err := buf.Read(addressesBuf)
|
||||
if err != nil {
|
||||
parser.logger.Printf("Read address error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
parser.logger.Printf("Address readed: %d", addressReaded)
|
||||
|
||||
switch versionCommandByte & BinaryCommandMask {
|
||||
case BinaryCommandProxy:
|
||||
return parserBinaryCommandHeader(metaBuf[protocolPos], addressesBuf)
|
||||
case BinaryCommandLocal:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, ErrUnknownCommand
|
||||
}
|
||||
}
|
||||
|
||||
func parserBinaryCommandHeader(protocol byte, addressesBuf []byte) (*Header, error) {
|
||||
switch protocol & BinaryAFMask {
|
||||
case BinaryProtocolUnspec:
|
||||
return nil, nil
|
||||
case BinaryAFInet:
|
||||
return parseAddressData(addressesBuf, net.IPv4len)
|
||||
case BinaryAFInet6:
|
||||
return parseAddressData(addressesBuf, net.IPv6len)
|
||||
default:
|
||||
return nil, ErrUnknownProtocol
|
||||
}
|
||||
}
|
||||
|
||||
func parseAddressData(addressesBuf []byte, ipLen int) (*Header, error) {
|
||||
expectedBufSize := 2 * (ipLen + BinaryPortLen)
|
||||
if len(addressesBuf) < expectedBufSize {
|
||||
return nil, ErrUnexpectedAddressLen
|
||||
}
|
||||
|
||||
srcIP := make(net.IP, ipLen)
|
||||
copy(srcIP, addressesBuf[:ipLen])
|
||||
addressesBuf = addressesBuf[ipLen:]
|
||||
|
||||
dstIP := make(net.IP, ipLen)
|
||||
copy(dstIP, addressesBuf[:ipLen])
|
||||
addressesBuf = addressesBuf[ipLen:]
|
||||
|
||||
srcPort := binary.BigEndian.Uint16(addressesBuf[:BinaryPortLen])
|
||||
addressesBuf = addressesBuf[BinaryPortLen:]
|
||||
|
||||
dstPort := binary.BigEndian.Uint16(addressesBuf[:BinaryPortLen])
|
||||
// addressesBuf = addressesBuf[BinaryPortLen:]
|
||||
|
||||
return &Header{
|
||||
SrcAddr: &net.TCPAddr{
|
||||
IP: srcIP,
|
||||
Port: int(srcPort),
|
||||
},
|
||||
DstAddr: &net.TCPAddr{
|
||||
IP: dstIP,
|
||||
Port: int(dstPort),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
86
conn.go
Normal file
86
conn.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Conn is wrapper on net.Conn with RemoteAddr() override.
|
||||
//
|
||||
// On first call Read() or RemoteAddr() parse proxyprotocol header and store
|
||||
// local and remote addresses.
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
logger Logger
|
||||
readBuf *bufio.Reader
|
||||
header *Header
|
||||
headerErr error
|
||||
headerParser HeaderParser
|
||||
trustedAddr bool
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewConn create wrapper on net.Conn.
|
||||
func NewConn(conn net.Conn, headerParser HeaderParser) net.Conn {
|
||||
readBuf := bufio.NewReaderSize(conn, bufferSize)
|
||||
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
readBuf: readBuf,
|
||||
logger: nil,
|
||||
headerParser: headerParser,
|
||||
trustedAddr: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) parseHeader() {
|
||||
conn.header, conn.headerErr = conn.headerParser.Parse(conn.readBuf)
|
||||
if conn.headerErr != nil {
|
||||
conn.logger.Printf("Header parse error: %s", conn.headerErr)
|
||||
return
|
||||
}
|
||||
conn.logger.Printf("Header parsed %v", conn.header)
|
||||
}
|
||||
|
||||
// Read on first call parse proxyprotocol header.
|
||||
//
|
||||
// If header parser return error, then error stored and returned. Otherwise call
|
||||
// Read on source connection.
|
||||
//
|
||||
// Following calls of Read function check parse header error.
|
||||
// If error not nil, then error returned. Otherwise called source "conn.Read".
|
||||
func (conn *Conn) Read(buf []byte) (int, error) {
|
||||
conn.once.Do(conn.parseHeader)
|
||||
|
||||
if conn.headerErr != nil {
|
||||
return 0, conn.headerErr
|
||||
}
|
||||
|
||||
return conn.readBuf.Read(buf)
|
||||
}
|
||||
|
||||
// LocalAddr proxy to conn.LocalAddr
|
||||
func (conn *Conn) LocalAddr() net.Addr {
|
||||
conn.once.Do(conn.parseHeader)
|
||||
|
||||
if conn.trustedAddr && conn.header != nil {
|
||||
return conn.header.DstAddr
|
||||
}
|
||||
|
||||
return conn.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr on first call parse proxyprotocol header.
|
||||
//
|
||||
// If header parser return header, then return source address from header.
|
||||
// Otherwise return original source address.
|
||||
func (conn *Conn) RemoteAddr() net.Addr {
|
||||
conn.once.Do(conn.parseHeader)
|
||||
|
||||
if conn.trustedAddr && conn.header != nil {
|
||||
return conn.header.SrcAddr
|
||||
}
|
||||
|
||||
return conn.Conn.RemoteAddr()
|
||||
}
|
||||
84
fallback_receive.go
Normal file
84
fallback_receive.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// StubHeaderParser always return nil Header
|
||||
type StubHeaderParser struct{}
|
||||
|
||||
// NewStubHeaderParser construct StubHeaderParser
|
||||
func NewStubHeaderParser() StubHeaderParser {
|
||||
return StubHeaderParser{}
|
||||
}
|
||||
|
||||
// Parse always return nil, nil
|
||||
func (parser StubHeaderParser) Parse(*bufio.Reader) (*Header, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// FallbackHeaderParserBuilder build FallbackHeaderParser
|
||||
type FallbackHeaderParserBuilder []HeaderParserBuilder
|
||||
|
||||
// NewFallbackHeaderParserBuilder construct FallbackHeaderParserBuilder
|
||||
func NewFallbackHeaderParserBuilder(
|
||||
headerParserBuilders ...HeaderParserBuilder,
|
||||
) FallbackHeaderParserBuilder {
|
||||
return FallbackHeaderParserBuilder(headerParserBuilders)
|
||||
}
|
||||
|
||||
// Build FallbackHeaderParser from headerParserBuilders
|
||||
func (headerParserBuilders FallbackHeaderParserBuilder) Build(logger Logger) HeaderParser {
|
||||
headerParsers := make([]HeaderParser, 0, len(headerParserBuilders))
|
||||
for _, headerParserBuilder := range headerParserBuilders {
|
||||
headerParser := headerParserBuilder.Build(logger)
|
||||
headerParsers = append(headerParsers, headerParser)
|
||||
}
|
||||
return FallbackHeaderParser{
|
||||
Logger: logger,
|
||||
HeaderParsers: headerParsers,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidHeader returned by FallbackHeaderParser when all headerParsers return
|
||||
// ErrInvalidSignature
|
||||
var ErrInvalidHeader = errors.New("invalid header")
|
||||
|
||||
// FallbackHeaderParser iterate over HeaderParser until parser not return nil error.
|
||||
type FallbackHeaderParser struct {
|
||||
Logger Logger
|
||||
HeaderParsers []HeaderParser
|
||||
}
|
||||
|
||||
// NewFallbackHeaderParser create new instance of FallbackHeaderParser
|
||||
func NewFallbackHeaderParser(logger Logger, headerParsers ...HeaderParser) FallbackHeaderParser {
|
||||
return FallbackHeaderParser{
|
||||
Logger: logger,
|
||||
HeaderParsers: headerParsers,
|
||||
}
|
||||
}
|
||||
|
||||
// Parse iterate over headerParsers call Parse().
|
||||
//
|
||||
// If any parser return not nil or not ErrInvalidSignature error, then return its error.
|
||||
//
|
||||
// If any parser return nil error, then return header.
|
||||
//
|
||||
// If all parsers return error ErrInvalidSignature, then return ErrInvalidHeader.
|
||||
func (parser FallbackHeaderParser) Parse(buf *bufio.Reader) (*Header, error) {
|
||||
for _, headerParser := range parser.HeaderParsers {
|
||||
header, err := headerParser.Parse(buf)
|
||||
switch err {
|
||||
case nil:
|
||||
parser.Logger.Printf("Use header remote addr")
|
||||
return header, nil
|
||||
case ErrInvalidSignature:
|
||||
continue
|
||||
default:
|
||||
parser.Logger.Printf("Parse header error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, ErrInvalidHeader
|
||||
}
|
||||
5
go.mod
Normal file
5
go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module github.com/c0va23/go-proxyprotocol
|
||||
|
||||
go 1.12
|
||||
|
||||
require github.com/golang/mock v1.2.0
|
||||
2
go.sum
Normal file
2
go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
|
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
78
listener.go
Normal file
78
listener.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
const bufferSize = 1400
|
||||
|
||||
// SourceChecker check trusted address
|
||||
type SourceChecker func(net.Addr) (bool, error)
|
||||
|
||||
// NewListener construct Listener
|
||||
func NewListener(listener net.Listener) Listener {
|
||||
return Listener{
|
||||
Listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// Listener implement net.Listener
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
}
|
||||
|
||||
// HeaderParserBuilderFunc wrap builder func into HeaderParserBuilder
|
||||
type HeaderParserBuilderFunc func(logger Logger) HeaderParser
|
||||
|
||||
// Build implement HeaderParserBuilder for build func
|
||||
func (funcBuilder HeaderParserBuilderFunc) Build(logger Logger) HeaderParser {
|
||||
return funcBuilder(logger)
|
||||
}
|
||||
|
||||
// TextHeaderParserBuilder build TextHeaderParser
|
||||
var TextHeaderParserBuilder = HeaderParserBuilderFunc(func(logger Logger) HeaderParser {
|
||||
return NewTextHeaderParser(logger)
|
||||
})
|
||||
|
||||
// BinaryHeaderParserBuilder build BinaryHeaderParser
|
||||
var BinaryHeaderParserBuilder = HeaderParserBuilderFunc(func(logger Logger) HeaderParser {
|
||||
return NewBinaryHeaderParser(logger)
|
||||
})
|
||||
|
||||
// StubHeaderParserBuilder build StubHeaderParser
|
||||
var StubHeaderParserBuilder = HeaderParserBuilderFunc(func(logger Logger) HeaderParser {
|
||||
return NewStubHeaderParser()
|
||||
})
|
||||
|
||||
// Otherwise connection wrapped into Conn with header parser.
|
||||
func (listener Listener) Accept() (net.Conn, error) {
|
||||
rawConn, err := listener.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger := FallbackLogger{Logger: nil}
|
||||
// trusted := true
|
||||
// if listener.SourceChecker != nil {
|
||||
// trusted, err = listener.SourceChecker(rawConn.RemoteAddr())
|
||||
// if err != nil {
|
||||
// logger.Printf("Source check error: %s", err)
|
||||
// return nil, err
|
||||
// }
|
||||
// }
|
||||
|
||||
// if trusted {
|
||||
// logger.Printf("Trusted connection")
|
||||
// } else {
|
||||
// logger.Printf("Not trusted connection")
|
||||
// }
|
||||
|
||||
// NOTE strictly a proxy protocol implementation without plain connections
|
||||
headerParser := NewFallbackHeaderParserBuilder(
|
||||
TextHeaderParserBuilder,
|
||||
BinaryHeaderParserBuilder,
|
||||
// StubHeaderParserBuilder,
|
||||
).Build(logger)
|
||||
|
||||
return NewConn(rawConn, headerParser), nil
|
||||
}
|
||||
27
logger.go
Normal file
27
logger.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package proxyprotocol
|
||||
|
||||
// Logger interface
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// LoggerFunc wrap Printf-like function into proxyprotocol.Logger
|
||||
type LoggerFunc func(format string, v ...interface{})
|
||||
|
||||
// Printf call inner Printf-link function
|
||||
func (logf LoggerFunc) Printf(format string, v ...interface{}) {
|
||||
logf(format, v...)
|
||||
}
|
||||
|
||||
// FallbackLogger wrap Logger or nil
|
||||
type FallbackLogger struct {
|
||||
Logger
|
||||
}
|
||||
|
||||
// Printf call Printf on inner logger if it not nil
|
||||
func (wrapper FallbackLogger) Printf(format string, v ...interface{}) {
|
||||
if wrapper.Logger == nil {
|
||||
return
|
||||
}
|
||||
wrapper.Logger.Printf(format, v...)
|
||||
}
|
||||
36
proxyprotocol.go
Normal file
36
proxyprotocol.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Package proxyprotocol implement receiver for HA Proxy Protocol V1 and V2.
|
||||
//
|
||||
// Proxy Protocol spec http://www.haproxy.org/download/2.0/doc/proxy-protocol.txt
|
||||
//
|
||||
// This package provides a wrapper for the interface net.Listener, which extracts
|
||||
// remote and local address of the connection from the headers in the format
|
||||
// HA proxy protocol.
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Header struct represent header parsing result
|
||||
type Header struct {
|
||||
SrcAddr net.Addr
|
||||
DstAddr net.Addr
|
||||
}
|
||||
|
||||
// HeaderParserBuilder build HeaderParser's
|
||||
type HeaderParserBuilder interface {
|
||||
Build(Logger) HeaderParser
|
||||
}
|
||||
|
||||
// HeaderParser describe interface for header parsers
|
||||
type HeaderParser interface {
|
||||
Parse(readBuf *bufio.Reader) (*Header, error)
|
||||
}
|
||||
|
||||
// Shared HeaderParser errors
|
||||
var (
|
||||
ErrInvalidSignature = errors.New("invalid signature")
|
||||
ErrUnknownProtocol = errors.New("unknown protocol")
|
||||
)
|
||||
23
text.go
Normal file
23
text.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package proxyprotocol
|
||||
|
||||
// TextSignature is prefix for proxyprotocol v1
|
||||
var (
|
||||
TextSignature = []byte("PROXY")
|
||||
TextSeparator = " "
|
||||
TextCR = byte('\r')
|
||||
TextLF = byte('\n')
|
||||
TextCRLF = []byte{TextCR, TextLF}
|
||||
)
|
||||
|
||||
var (
|
||||
textSignatureLen = len(TextSignature)
|
||||
textAddressPartsLen = 4
|
||||
textPortBitSize = 16
|
||||
)
|
||||
|
||||
// TextProtocol list
|
||||
var (
|
||||
TextProtocolIPv4 = "TCP4"
|
||||
TextProtocolIPv6 = "TCP6"
|
||||
TextProtocolUnknown = "UNKNOWN"
|
||||
)
|
||||
108
text_receive.go
Normal file
108
text_receive.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package proxyprotocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Text protocol errors
|
||||
var (
|
||||
ErrInvalidAddressList = errors.New("invalid address list")
|
||||
ErrInvalidIP = errors.New("invalid IP")
|
||||
ErrInvalidPort = errors.New("invalid port")
|
||||
)
|
||||
|
||||
// TextHeaderParser for proxyprotocol v1
|
||||
type TextHeaderParser struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewTextHeaderParser construct TextHeaderParser
|
||||
func NewTextHeaderParser(logger Logger) TextHeaderParser {
|
||||
return TextHeaderParser{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Parse proxyprotocol v1 header
|
||||
func (parser TextHeaderParser) Parse(buf *bufio.Reader) (*Header, error) {
|
||||
signatureBuf, err := buf.Peek(textSignatureLen)
|
||||
if err != nil {
|
||||
parser.logger.Printf("Read text signature error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !bytes.Equal(signatureBuf, TextSignature) {
|
||||
return nil, ErrInvalidSignature
|
||||
}
|
||||
|
||||
headerLine, err := buf.ReadString(TextLF)
|
||||
if err != nil {
|
||||
parser.logger.Printf("Read header line error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Strip CR char on line end
|
||||
if headerLine[len(headerLine)-2] == TextCR {
|
||||
headerLine = headerLine[:len(headerLine)-2]
|
||||
}
|
||||
|
||||
headerParts := strings.Split(headerLine, TextSeparator)
|
||||
|
||||
protocol := headerParts[1]
|
||||
|
||||
switch protocol {
|
||||
case TextProtocolUnknown:
|
||||
return nil, nil
|
||||
case TextProtocolIPv4, TextProtocolIPv6:
|
||||
return parseTextHeader(headerParts)
|
||||
default:
|
||||
return nil, ErrUnknownProtocol
|
||||
}
|
||||
}
|
||||
|
||||
func parseTextHeader(headerParts []string) (*Header, error) {
|
||||
addressParts := headerParts[2:]
|
||||
if textAddressPartsLen != len(addressParts) {
|
||||
return nil, ErrInvalidAddressList
|
||||
}
|
||||
|
||||
srcIPStr := addressParts[0]
|
||||
srcIP := net.ParseIP(srcIPStr)
|
||||
if srcIP == nil {
|
||||
return nil, ErrInvalidIP
|
||||
}
|
||||
|
||||
dstIPStr := addressParts[1]
|
||||
dstIP := net.ParseIP(dstIPStr)
|
||||
if dstIP == nil {
|
||||
return nil, ErrInvalidIP
|
||||
}
|
||||
|
||||
srcPortSrt := addressParts[2]
|
||||
srcPort, err := strconv.ParseUint(srcPortSrt, 10, textPortBitSize)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
|
||||
dstPortSrt := addressParts[3]
|
||||
dstPort, err := strconv.ParseUint(dstPortSrt, 10, textPortBitSize)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidPort
|
||||
}
|
||||
|
||||
return &Header{
|
||||
SrcAddr: &net.TCPAddr{
|
||||
IP: srcIP,
|
||||
Port: int(srcPort),
|
||||
},
|
||||
DstAddr: &net.TCPAddr{
|
||||
IP: dstIP,
|
||||
Port: int(dstPort),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user