Files
go-proxyprotocol/conn.go
2025-12-10 07:46:10 +02:00

87 lines
2.0 KiB
Go

package proxyprotocol
import (
"bufio"
"net"
"sync"
)
// Conn is wrapper on net.Conn with RemoteAddr() override.
//
// On first call Read() or RemoteAddr() parse proxyprotocol header and store
// local and remote addresses.
type Conn struct {
net.Conn
logger Logger
readBuf *bufio.Reader
header *Header
headerErr error
headerParser HeaderParser
trustedAddr bool
once sync.Once
}
// NewConn create wrapper on net.Conn.
func NewConn(conn net.Conn, headerParser HeaderParser) net.Conn {
readBuf := bufio.NewReaderSize(conn, bufferSize)
return &Conn{
Conn: conn,
readBuf: readBuf,
logger: nil,
headerParser: headerParser,
trustedAddr: true,
}
}
func (conn *Conn) parseHeader() {
conn.header, conn.headerErr = conn.headerParser.Parse(conn.readBuf)
// if conn.headerErr != nil {
// conn.logger.Printf("Header parse error: %s", conn.headerErr)
// return
// }
// conn.logger.Printf("Header parsed %v", conn.header)
}
// Read on first call parse proxyprotocol header.
//
// If header parser return error, then error stored and returned. Otherwise call
// Read on source connection.
//
// Following calls of Read function check parse header error.
// If error not nil, then error returned. Otherwise called source "conn.Read".
func (conn *Conn) Read(buf []byte) (int, error) {
conn.once.Do(conn.parseHeader)
if conn.headerErr != nil {
return 0, conn.headerErr
}
return conn.readBuf.Read(buf)
}
// LocalAddr proxy to conn.LocalAddr
func (conn *Conn) LocalAddr() net.Addr {
conn.once.Do(conn.parseHeader)
if conn.trustedAddr && conn.header != nil {
return conn.header.DstAddr
}
return conn.Conn.LocalAddr()
}
// RemoteAddr on first call parse proxyprotocol header.
//
// If header parser return header, then return source address from header.
// Otherwise return original source address.
func (conn *Conn) RemoteAddr() net.Addr {
conn.once.Do(conn.parseHeader)
if conn.trustedAddr && conn.header != nil {
return conn.header.SrcAddr
}
return conn.Conn.RemoteAddr()
}