87 lines
2.0 KiB
Go
87 lines
2.0 KiB
Go
package proxyprotocol
|
|
|
|
import (
|
|
"bufio"
|
|
"net"
|
|
"sync"
|
|
)
|
|
|
|
// Conn is wrapper on net.Conn with RemoteAddr() override.
|
|
//
|
|
// On first call Read() or RemoteAddr() parse proxyprotocol header and store
|
|
// local and remote addresses.
|
|
type Conn struct {
|
|
net.Conn
|
|
logger Logger
|
|
readBuf *bufio.Reader
|
|
header *Header
|
|
headerErr error
|
|
headerParser HeaderParser
|
|
trustedAddr bool
|
|
once sync.Once
|
|
}
|
|
|
|
// NewConn create wrapper on net.Conn.
|
|
func NewConn(conn net.Conn, headerParser HeaderParser) net.Conn {
|
|
readBuf := bufio.NewReaderSize(conn, bufferSize)
|
|
|
|
return &Conn{
|
|
Conn: conn,
|
|
readBuf: readBuf,
|
|
logger: nil,
|
|
headerParser: headerParser,
|
|
trustedAddr: true,
|
|
}
|
|
}
|
|
|
|
func (conn *Conn) parseHeader() {
|
|
conn.header, conn.headerErr = conn.headerParser.Parse(conn.readBuf)
|
|
if conn.headerErr != nil {
|
|
// conn.logger.Printf("Header parse error: %s", conn.headerErr)
|
|
return
|
|
}
|
|
// conn.logger.Printf("Header parsed %v", conn.header)
|
|
}
|
|
|
|
// Read on first call parse proxyprotocol header.
|
|
//
|
|
// If header parser return error, then error stored and returned. Otherwise call
|
|
// Read on source connection.
|
|
//
|
|
// Following calls of Read function check parse header error.
|
|
// If error not nil, then error returned. Otherwise called source "conn.Read".
|
|
func (conn *Conn) Read(buf []byte) (int, error) {
|
|
conn.once.Do(conn.parseHeader)
|
|
|
|
if conn.headerErr != nil {
|
|
return 0, conn.headerErr
|
|
}
|
|
|
|
return conn.readBuf.Read(buf)
|
|
}
|
|
|
|
// LocalAddr proxy to conn.LocalAddr
|
|
func (conn *Conn) LocalAddr() net.Addr {
|
|
conn.once.Do(conn.parseHeader)
|
|
|
|
if conn.trustedAddr && conn.header != nil {
|
|
return conn.header.DstAddr
|
|
}
|
|
|
|
return conn.Conn.LocalAddr()
|
|
}
|
|
|
|
// RemoteAddr on first call parse proxyprotocol header.
|
|
//
|
|
// If header parser return header, then return source address from header.
|
|
// Otherwise return original source address.
|
|
func (conn *Conn) RemoteAddr() net.Addr {
|
|
conn.once.Do(conn.parseHeader)
|
|
|
|
if conn.trustedAddr && conn.header != nil {
|
|
return conn.header.SrcAddr
|
|
}
|
|
|
|
return conn.Conn.RemoteAddr()
|
|
}
|