fully functional implementation

This commit is contained in:
Joe Julian
2026-02-05 10:43:40 -08:00
parent 99f1ebeaf2
commit bc092dd706
8 changed files with 503 additions and 0 deletions
+8
View File
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/udp-reflector.iml" filepath="$PROJECT_DIR$/.idea/udp-reflector.iml" />
</modules>
</component>
</project>
+9
View File
@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
Generated
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>
+13
View File
@@ -0,0 +1,13 @@
module github.com/joejulian/udp-reflector
go 1.25.5
require (
github.com/rs/zerolog v1.34.0
golang.org/x/sys v0.39.0
)
require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
)
+16
View File
@@ -0,0 +1,16 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
BIN
View File
Binary file not shown.
+443
View File
@@ -0,0 +1,443 @@
// udp-bcast-cidr-relay.go
// Build (static, no cgo):
//
// CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -trimpath -ldflags='-s -w -extldflags "-static"' -o udp-bcast-cidr-relay ./udp-bcast-cidr-relay.go
package main
import (
"encoding/binary"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/rs/zerolog"
"golang.org/x/sys/unix"
)
type subnetEntry struct {
IfName string
IfIdx int
Net *net.IPNet // IPv4
Bcast net.IP // 4 bytes
}
type recvMeta struct {
IfIndex int
TTL int
}
func ip4ToU32(ip net.IP) (uint32, bool) {
ip4 := ip.To4()
if ip4 == nil || len(ip4) != 4 {
return 0, false
}
return binary.BigEndian.Uint32(ip.To4()), true
}
func localIPv4SetFromSubnets(subnets []subnetEntry) map[uint32]bool {
exists := map[uint32]bool{}
for _, s := range subnets {
if u, ok := ip4ToU32(s.Net.IP); ok {
exists[u] = true
}
}
return exists
}
func parseCIDRs(csv string) ([]*net.IPNet, error) {
if strings.TrimSpace(csv) == "" {
return nil, fmt.Errorf("cidrs required")
}
var out []*net.IPNet
for _, s := range strings.Split(csv, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}
_, n, err := net.ParseCIDR(s)
if err != nil {
return nil, fmt.Errorf("bad cidr %q: %w", s, err)
}
// normalize to IPv4 nets only
ip4 := n.IP.To4()
if ip4 == nil || len(n.Mask) != 4 {
return nil, fmt.Errorf("cidr must be IPv4: %q", s)
}
n.IP = ip4
out = append(out, n)
}
return out, nil
}
func parsePorts(csv string) ([]int, error) {
if strings.TrimSpace(csv) == "" {
return nil, fmt.Errorf("ports required")
}
var out []int
for _, s := range strings.Split(csv, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}
p, err := strconv.Atoi(s)
if err != nil || p < 1 || p > 65535 {
return nil, fmt.Errorf("bad port %q", s)
}
out = append(out, p)
}
return out, nil
}
func ipv4Bcast(ip net.IP, mask net.IPMask) (net.IP, bool) {
ip4 := ip.To4()
if ip4 == nil || len(mask) != 4 {
return nil, false
}
b := make(net.IP, 4)
for i := 0; i < 4; i++ {
b[i] = ip4[i] | ^mask[i]
}
return b, true
}
func netWithinAny(n *net.IPNet, cidrs []*net.IPNet) bool {
// treat "within" as: network address is inside CIDR and broadcast is inside CIDR
b, ok := ipv4Bcast(n.IP, n.Mask)
if !ok {
return false
}
for _, c := range cidrs {
if c.Contains(n.IP) && c.Contains(b) {
return true
}
}
return false
}
func discoverSubnets(cidrs []*net.IPNet, ifacesCSV string) ([]subnetEntry, error) {
allow := map[string]bool{}
if strings.TrimSpace(ifacesCSV) != "" {
for _, s := range strings.Split(ifacesCSV, ",") {
s = strings.TrimSpace(s)
if s != "" {
allow[s] = true
}
}
}
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var subs []subnetEntry
for _, ifc := range ifaces {
if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagLoopback != 0 {
continue
}
if len(allow) > 0 && !allow[ifc.Name] {
continue
}
addrs, err := ifc.Addrs()
if err != nil {
return nil, fmt.Errorf("addrs(%s): %w", ifc.Name, err)
}
for _, a := range addrs {
ipnet, ok := a.(*net.IPNet)
if !ok || ipnet.IP == nil {
continue
}
ip4 := ipnet.IP.To4()
if ip4 == nil || len(ipnet.Mask) != 4 {
continue
}
n := &net.IPNet{IP: ip4, Mask: ipnet.Mask}
if !netWithinAny(n, cidrs) {
continue
}
b, ok := ipv4Bcast(n.IP, n.Mask)
if !ok {
continue
}
subs = append(subs, subnetEntry{
IfName: ifc.Name,
IfIdx: ifc.Index,
Net: n,
Bcast: b,
})
}
}
// de-dupe (same subnet repeated)
seen := map[string]bool{}
var out []subnetEntry
for _, s := range subs {
k := fmt.Sprintf("%d-%s/%d", s.IfIdx, s.Net.IP.String(), maskBits(s.Net.Mask))
if !seen[k] {
seen[k] = true
out = append(out, s)
}
}
sort.Slice(out, func(i, j int) bool {
if out[i].IfName == out[j].IfName {
return binary.BigEndian.Uint32(out[i].Bcast) < binary.BigEndian.Uint32(out[j].Bcast)
}
return out[i].IfName < out[j].IfName
})
return out, nil
}
func maskBits(m net.IPMask) int {
ones, _ := m.Size()
return ones
}
func parseOOB(oob []byte) recvMeta {
meta := recvMeta{IfIndex: 0, TTL: -1}
msgs, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return meta
}
for _, m := range msgs {
switch {
case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_PKTINFO:
// in_pktinfo: ifindex (u32 LE) + spec_dst + addr
if len(m.Data) >= 12 {
meta.IfIndex = int(binary.LittleEndian.Uint32(m.Data[0:4]))
}
case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_TTL:
switch len(m.Data) {
case 1:
meta.TTL = int(m.Data[0])
case 4:
ttlLE := int(binary.LittleEndian.Uint32(m.Data))
ttlBE := int(binary.BigEndian.Uint32(m.Data))
if ttlLE >= 0 && ttlLE <= 255 {
meta.TTL = ttlLE
} else {
meta.TTL = ttlBE
}
}
}
}
return meta
}
type sender struct {
ifname string
conn net.PacketConn
}
func newSender(ifname string, ttl int) (*sender, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err != nil {
return nil, err
}
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
_ = unix.Close(fd)
return nil, err
}
if err = unix.SetsockoptString(fd, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); err != nil {
_ = unix.Close(fd)
return nil, err
}
if err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, ttl); err != nil {
_ = unix.Close(fd)
return nil, err
}
f := os.NewFile(uintptr(fd), "udp-send-"+ifname)
pc, err := net.FilePacketConn(f)
_ = f.Close()
if err != nil {
_ = unix.Close(fd)
return nil, err
}
return &sender{ifname: ifname, conn: pc}, nil
}
func main() {
cidrsFlag := flag.String("cidrs", "", "comma-separated IPv4 CIDRs to include (for example 192.168.0.0/16,10.0.0.0/8)")
ifacesFlag := flag.String("ifaces", "", "optional comma-separated interface allow-list (empty = all UP non-loopback)")
portsFlag := flag.String("ports", "32410,32412,32413,32414,50222", "comma-separated UDP ports to relay")
ttlFlag := flag.Int("ttl", 1, "outbound TTL for forwarded packets (use 1)")
dropTTLFlag := flag.Int("drop-ttl", 1, "drop inbound packets with this TTL (use 1 to drop your forwarded copies)")
timeoutMS := flag.Int("timeout-ms", 250, "read timeout in ms")
logLevelFlag := flag.String("log-level", "info", "log level: trace|debug|info|warn|error|fatal|panic")
flag.Parse()
lvl, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(*logLevelFlag)))
if err != nil {
fmt.Fprintf(os.Stderr, "WARN: bad --log-level %q; using info\n", *logLevelFlag)
lvl = zerolog.InfoLevel
}
lg := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(lvl)
cidrs, err := parseCIDRs(*cidrsFlag)
if err != nil {
lg.Fatal().Msgf("cidrs: %v", err)
}
ports, err := parsePorts(*portsFlag)
if err != nil {
lg.Fatal().Msgf("ports: %v", err)
}
subnets, err := discoverSubnets(cidrs, *ifacesFlag)
if err != nil {
lg.Fatal().Msgf("discover: %v", err)
}
if len(subnets) < 2 {
lg.Fatal().Msgf("need at least 2 discovered subnets inside -cidrs; got %d", len(subnets))
}
localIPs := localIPv4SetFromSubnets(subnets)
// One sender per interface name
senders := map[string]*sender{}
for _, s := range subnets {
if _, ok := senders[s.IfName]; ok {
continue
}
se, err := newSender(s.IfName, *ttlFlag)
if err != nil {
lg.Fatal().Msgf("sender(%s): %v", s.IfName, err)
}
senders[s.IfName] = se
}
defer func() {
for _, s := range senders {
_ = s.conn.Close()
}
}()
lg.Info().Msgf("Discovered subnets in %s:", *cidrsFlag)
for _, s := range subnets {
lg.Info().Msgf(" if=%s idx=%d net=%s bcast=%s", s.IfName, s.IfIdx, s.Net.String(), s.Bcast.String())
}
lg.Info().Msgf("Relaying UDP ports=%v ttl=%d dropTTL=%d", ports, *ttlFlag, *dropTTLFlag)
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
for _, p := range ports {
port := p
go func() {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err != nil {
lg.Fatal().Msgf("socket :%d: %v", port, err)
}
defer unix.Close(fd)
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
lg.Fatal().Msgf("reuseaddr :%d: %v", port, err)
}
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
lg.Fatal().Msgf("broadcast :%d: %v", port, err)
}
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_PKTINFO, 1); err != nil {
lg.Fatal().Msgf("pktinfo :%d: %v", port, err)
}
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1); err != nil {
lg.Fatal().Msgf("recvttl :%d: %v", port, err)
}
sa := &unix.SockaddrInet4{Port: port}
copy(sa.Addr[:], net.IPv4zero.To4())
if err := unix.Bind(fd, sa); err != nil {
lg.Fatal().Msgf("bind :%d: %v", port, err)
}
buf := make([]byte, 65535)
oob := make([]byte, 256)
for {
_ = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{
Sec: 0,
Usec: int64(*timeoutMS) * 1000,
})
n, oobn, _, from, err := unix.Recvmsg(fd, buf, oob, 0)
if err != nil {
// timeout
if errors.Is(err, unix.EAGAIN) ||
errors.Is(err, unix.EWOULDBLOCK) ||
errors.Is(err, unix.EINTR) {
continue
}
lg.Warn().Msgf("recvmsg :%d: %v (stopping listener goroutine)", port, err)
return
}
meta := parseOOB(oob[:oobn])
if meta.TTL == *dropTTLFlag || meta.IfIndex == 0 {
lg.Debug().Msgf("drop :%d: meta ttl=%d ifindex=%d (dropTTL=%d)", port, meta.TTL, meta.IfIndex, *dropTTLFlag)
continue
}
var srcIP net.IP
switch fa := from.(type) {
case *unix.SockaddrInet4:
srcIP = net.IPv4(fa.Addr[0], fa.Addr[1], fa.Addr[2], fa.Addr[3]).To4()
default:
lg.Debug().Msgf("drop :%d: unexpected from addr type %T", port, from)
continue
}
if srcIP == nil {
continue
}
// if TTL is our forwarded TTL, drop.
if meta.TTL == *dropTTLFlag && *dropTTLFlag >= 0 {
lg.Debug().Msgf("drop :%d: ttl matches dropTTL src=%s ttl=%d", port, srcIP.String(), meta.TTL)
continue
}
// If the packet's source IP is one of ours, it's a rebroadcast
// we generated. Drop it to avoid looping.
if u, ok := ip4ToU32(srcIP); ok && localIPs[u] {
lg.Debug().Msgf("drop :%d: looped packet src=%s", port, srcIP.String())
continue
}
// Determine which discovered subnet this packet originated from (on the ingress ifindex).
origin := map[string]bool{} // key: ifidx+net string
for _, sn := range subnets {
if sn.IfIdx != meta.IfIndex {
continue
}
if sn.Net.Contains(srcIP) {
origin[fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())] = true
}
}
payload := buf[:n]
lg.Debug().Msgf("recv :%d: %d bytes from=%s ifindex=%d ttl=%d", port, n, srcIP.String(), meta.IfIndex, meta.TTL)
// Fan out to every broadcast target except the origin subnet(s).
for _, sn := range subnets {
k := fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())
if origin[k] {
continue
}
s := senders[sn.IfName]
if s == nil {
continue
}
dst := &net.UDPAddr{IP: sn.Bcast, Port: port}
_, _ = s.conn.WriteTo(payload, dst)
lg.Debug().Msgf("sent :%d: %d bytes to=%s via=%s", port, len(payload), dst.String(), sn.IfName)
}
}
}()
}
<-stop
lg.Info().Msg("Shutting down")
time.Sleep(100 * time.Millisecond)
}