fully functional implementation
This commit is contained in:
@@ -0,0 +1,443 @@
|
||||
// udp-bcast-cidr-relay.go
|
||||
// Build (static, no cgo):
|
||||
//
|
||||
// CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -trimpath -ldflags='-s -w -extldflags "-static"' -o udp-bcast-cidr-relay ./udp-bcast-cidr-relay.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type subnetEntry struct {
|
||||
IfName string
|
||||
IfIdx int
|
||||
Net *net.IPNet // IPv4
|
||||
Bcast net.IP // 4 bytes
|
||||
}
|
||||
|
||||
type recvMeta struct {
|
||||
IfIndex int
|
||||
TTL int
|
||||
}
|
||||
|
||||
func ip4ToU32(ip net.IP) (uint32, bool) {
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil || len(ip4) != 4 {
|
||||
return 0, false
|
||||
}
|
||||
return binary.BigEndian.Uint32(ip.To4()), true
|
||||
}
|
||||
|
||||
func localIPv4SetFromSubnets(subnets []subnetEntry) map[uint32]bool {
|
||||
exists := map[uint32]bool{}
|
||||
for _, s := range subnets {
|
||||
if u, ok := ip4ToU32(s.Net.IP); ok {
|
||||
exists[u] = true
|
||||
}
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
func parseCIDRs(csv string) ([]*net.IPNet, error) {
|
||||
if strings.TrimSpace(csv) == "" {
|
||||
return nil, fmt.Errorf("cidrs required")
|
||||
}
|
||||
var out []*net.IPNet
|
||||
for _, s := range strings.Split(csv, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
_, n, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad cidr %q: %w", s, err)
|
||||
}
|
||||
// normalize to IPv4 nets only
|
||||
ip4 := n.IP.To4()
|
||||
if ip4 == nil || len(n.Mask) != 4 {
|
||||
return nil, fmt.Errorf("cidr must be IPv4: %q", s)
|
||||
}
|
||||
n.IP = ip4
|
||||
out = append(out, n)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func parsePorts(csv string) ([]int, error) {
|
||||
if strings.TrimSpace(csv) == "" {
|
||||
return nil, fmt.Errorf("ports required")
|
||||
}
|
||||
var out []int
|
||||
for _, s := range strings.Split(csv, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
p, err := strconv.Atoi(s)
|
||||
if err != nil || p < 1 || p > 65535 {
|
||||
return nil, fmt.Errorf("bad port %q", s)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func ipv4Bcast(ip net.IP, mask net.IPMask) (net.IP, bool) {
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil || len(mask) != 4 {
|
||||
return nil, false
|
||||
}
|
||||
b := make(net.IP, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
b[i] = ip4[i] | ^mask[i]
|
||||
}
|
||||
return b, true
|
||||
}
|
||||
|
||||
func netWithinAny(n *net.IPNet, cidrs []*net.IPNet) bool {
|
||||
// treat "within" as: network address is inside CIDR and broadcast is inside CIDR
|
||||
b, ok := ipv4Bcast(n.IP, n.Mask)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, c := range cidrs {
|
||||
if c.Contains(n.IP) && c.Contains(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func discoverSubnets(cidrs []*net.IPNet, ifacesCSV string) ([]subnetEntry, error) {
|
||||
allow := map[string]bool{}
|
||||
if strings.TrimSpace(ifacesCSV) != "" {
|
||||
for _, s := range strings.Split(ifacesCSV, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
allow[s] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var subs []subnetEntry
|
||||
for _, ifc := range ifaces {
|
||||
if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
if len(allow) > 0 && !allow[ifc.Name] {
|
||||
continue
|
||||
}
|
||||
addrs, err := ifc.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("addrs(%s): %w", ifc.Name, err)
|
||||
}
|
||||
for _, a := range addrs {
|
||||
ipnet, ok := a.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
ip4 := ipnet.IP.To4()
|
||||
if ip4 == nil || len(ipnet.Mask) != 4 {
|
||||
continue
|
||||
}
|
||||
n := &net.IPNet{IP: ip4, Mask: ipnet.Mask}
|
||||
if !netWithinAny(n, cidrs) {
|
||||
continue
|
||||
}
|
||||
b, ok := ipv4Bcast(n.IP, n.Mask)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
subs = append(subs, subnetEntry{
|
||||
IfName: ifc.Name,
|
||||
IfIdx: ifc.Index,
|
||||
Net: n,
|
||||
Bcast: b,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// de-dupe (same subnet repeated)
|
||||
seen := map[string]bool{}
|
||||
var out []subnetEntry
|
||||
for _, s := range subs {
|
||||
k := fmt.Sprintf("%d-%s/%d", s.IfIdx, s.Net.IP.String(), maskBits(s.Net.Mask))
|
||||
if !seen[k] {
|
||||
seen[k] = true
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
if out[i].IfName == out[j].IfName {
|
||||
return binary.BigEndian.Uint32(out[i].Bcast) < binary.BigEndian.Uint32(out[j].Bcast)
|
||||
}
|
||||
return out[i].IfName < out[j].IfName
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func maskBits(m net.IPMask) int {
|
||||
ones, _ := m.Size()
|
||||
return ones
|
||||
}
|
||||
|
||||
func parseOOB(oob []byte) recvMeta {
|
||||
meta := recvMeta{IfIndex: 0, TTL: -1}
|
||||
msgs, err := unix.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return meta
|
||||
}
|
||||
for _, m := range msgs {
|
||||
switch {
|
||||
case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_PKTINFO:
|
||||
// in_pktinfo: ifindex (u32 LE) + spec_dst + addr
|
||||
if len(m.Data) >= 12 {
|
||||
meta.IfIndex = int(binary.LittleEndian.Uint32(m.Data[0:4]))
|
||||
}
|
||||
case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_TTL:
|
||||
switch len(m.Data) {
|
||||
case 1:
|
||||
meta.TTL = int(m.Data[0])
|
||||
case 4:
|
||||
ttlLE := int(binary.LittleEndian.Uint32(m.Data))
|
||||
ttlBE := int(binary.BigEndian.Uint32(m.Data))
|
||||
if ttlLE >= 0 && ttlLE <= 255 {
|
||||
meta.TTL = ttlLE
|
||||
} else {
|
||||
meta.TTL = ttlBE
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
type sender struct {
|
||||
ifname string
|
||||
conn net.PacketConn
|
||||
}
|
||||
|
||||
func newSender(ifname string, ttl int) (*sender, error) {
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err = unix.SetsockoptString(fd, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, ttl); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := os.NewFile(uintptr(fd), "udp-send-"+ifname)
|
||||
pc, err := net.FilePacketConn(f)
|
||||
_ = f.Close()
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
return &sender{ifname: ifname, conn: pc}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
cidrsFlag := flag.String("cidrs", "", "comma-separated IPv4 CIDRs to include (for example 192.168.0.0/16,10.0.0.0/8)")
|
||||
ifacesFlag := flag.String("ifaces", "", "optional comma-separated interface allow-list (empty = all UP non-loopback)")
|
||||
portsFlag := flag.String("ports", "32410,32412,32413,32414,50222", "comma-separated UDP ports to relay")
|
||||
ttlFlag := flag.Int("ttl", 1, "outbound TTL for forwarded packets (use 1)")
|
||||
dropTTLFlag := flag.Int("drop-ttl", 1, "drop inbound packets with this TTL (use 1 to drop your forwarded copies)")
|
||||
timeoutMS := flag.Int("timeout-ms", 250, "read timeout in ms")
|
||||
logLevelFlag := flag.String("log-level", "info", "log level: trace|debug|info|warn|error|fatal|panic")
|
||||
flag.Parse()
|
||||
|
||||
lvl, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(*logLevelFlag)))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARN: bad --log-level %q; using info\n", *logLevelFlag)
|
||||
lvl = zerolog.InfoLevel
|
||||
}
|
||||
lg := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(lvl)
|
||||
|
||||
cidrs, err := parseCIDRs(*cidrsFlag)
|
||||
if err != nil {
|
||||
lg.Fatal().Msgf("cidrs: %v", err)
|
||||
}
|
||||
ports, err := parsePorts(*portsFlag)
|
||||
if err != nil {
|
||||
lg.Fatal().Msgf("ports: %v", err)
|
||||
}
|
||||
|
||||
subnets, err := discoverSubnets(cidrs, *ifacesFlag)
|
||||
if err != nil {
|
||||
lg.Fatal().Msgf("discover: %v", err)
|
||||
}
|
||||
if len(subnets) < 2 {
|
||||
lg.Fatal().Msgf("need at least 2 discovered subnets inside -cidrs; got %d", len(subnets))
|
||||
}
|
||||
|
||||
localIPs := localIPv4SetFromSubnets(subnets)
|
||||
|
||||
// One sender per interface name
|
||||
senders := map[string]*sender{}
|
||||
for _, s := range subnets {
|
||||
if _, ok := senders[s.IfName]; ok {
|
||||
continue
|
||||
}
|
||||
se, err := newSender(s.IfName, *ttlFlag)
|
||||
if err != nil {
|
||||
lg.Fatal().Msgf("sender(%s): %v", s.IfName, err)
|
||||
}
|
||||
senders[s.IfName] = se
|
||||
}
|
||||
defer func() {
|
||||
for _, s := range senders {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
lg.Info().Msgf("Discovered subnets in %s:", *cidrsFlag)
|
||||
for _, s := range subnets {
|
||||
lg.Info().Msgf(" if=%s idx=%d net=%s bcast=%s", s.IfName, s.IfIdx, s.Net.String(), s.Bcast.String())
|
||||
}
|
||||
lg.Info().Msgf("Relaying UDP ports=%v ttl=%d dropTTL=%d", ports, *ttlFlag, *dropTTLFlag)
|
||||
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
for _, p := range ports {
|
||||
port := p
|
||||
go func() {
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
if err != nil {
|
||||
lg.Fatal().Msgf("socket :%d: %v", port, err)
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
|
||||
lg.Fatal().Msgf("reuseaddr :%d: %v", port, err)
|
||||
}
|
||||
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
|
||||
lg.Fatal().Msgf("broadcast :%d: %v", port, err)
|
||||
}
|
||||
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_PKTINFO, 1); err != nil {
|
||||
lg.Fatal().Msgf("pktinfo :%d: %v", port, err)
|
||||
}
|
||||
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1); err != nil {
|
||||
lg.Fatal().Msgf("recvttl :%d: %v", port, err)
|
||||
}
|
||||
|
||||
sa := &unix.SockaddrInet4{Port: port}
|
||||
copy(sa.Addr[:], net.IPv4zero.To4())
|
||||
if err := unix.Bind(fd, sa); err != nil {
|
||||
lg.Fatal().Msgf("bind :%d: %v", port, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
oob := make([]byte, 256)
|
||||
for {
|
||||
_ = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{
|
||||
Sec: 0,
|
||||
Usec: int64(*timeoutMS) * 1000,
|
||||
})
|
||||
|
||||
n, oobn, _, from, err := unix.Recvmsg(fd, buf, oob, 0)
|
||||
if err != nil {
|
||||
// timeout
|
||||
if errors.Is(err, unix.EAGAIN) ||
|
||||
errors.Is(err, unix.EWOULDBLOCK) ||
|
||||
errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
lg.Warn().Msgf("recvmsg :%d: %v (stopping listener goroutine)", port, err)
|
||||
return
|
||||
}
|
||||
meta := parseOOB(oob[:oobn])
|
||||
if meta.TTL == *dropTTLFlag || meta.IfIndex == 0 {
|
||||
lg.Debug().Msgf("drop :%d: meta ttl=%d ifindex=%d (dropTTL=%d)", port, meta.TTL, meta.IfIndex, *dropTTLFlag)
|
||||
continue
|
||||
}
|
||||
|
||||
var srcIP net.IP
|
||||
switch fa := from.(type) {
|
||||
case *unix.SockaddrInet4:
|
||||
srcIP = net.IPv4(fa.Addr[0], fa.Addr[1], fa.Addr[2], fa.Addr[3]).To4()
|
||||
default:
|
||||
lg.Debug().Msgf("drop :%d: unexpected from addr type %T", port, from)
|
||||
continue
|
||||
}
|
||||
if srcIP == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if TTL is our forwarded TTL, drop.
|
||||
if meta.TTL == *dropTTLFlag && *dropTTLFlag >= 0 {
|
||||
lg.Debug().Msgf("drop :%d: ttl matches dropTTL src=%s ttl=%d", port, srcIP.String(), meta.TTL)
|
||||
continue
|
||||
}
|
||||
|
||||
// If the packet's source IP is one of ours, it's a rebroadcast
|
||||
// we generated. Drop it to avoid looping.
|
||||
if u, ok := ip4ToU32(srcIP); ok && localIPs[u] {
|
||||
lg.Debug().Msgf("drop :%d: looped packet src=%s", port, srcIP.String())
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine which discovered subnet this packet originated from (on the ingress ifindex).
|
||||
origin := map[string]bool{} // key: ifidx+net string
|
||||
for _, sn := range subnets {
|
||||
if sn.IfIdx != meta.IfIndex {
|
||||
continue
|
||||
}
|
||||
if sn.Net.Contains(srcIP) {
|
||||
origin[fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())] = true
|
||||
}
|
||||
}
|
||||
|
||||
payload := buf[:n]
|
||||
|
||||
lg.Debug().Msgf("recv :%d: %d bytes from=%s ifindex=%d ttl=%d", port, n, srcIP.String(), meta.IfIndex, meta.TTL)
|
||||
|
||||
// Fan out to every broadcast target except the origin subnet(s).
|
||||
for _, sn := range subnets {
|
||||
k := fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())
|
||||
if origin[k] {
|
||||
continue
|
||||
}
|
||||
s := senders[sn.IfName]
|
||||
if s == nil {
|
||||
continue
|
||||
}
|
||||
dst := &net.UDPAddr{IP: sn.Bcast, Port: port}
|
||||
_, _ = s.conn.WriteTo(payload, dst)
|
||||
lg.Debug().Msgf("sent :%d: %d bytes to=%s via=%s", port, len(payload), dst.String(), sn.IfName)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
<-stop
|
||||
lg.Info().Msg("Shutting down")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
Reference in New Issue
Block a user