diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..95b1564 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/udp-reflector.iml b/.idea/udp-reflector.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/udp-reflector.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..64cfde9 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/joejulian/udp-reflector + +go 1.25.5 + +require ( + github.com/rs/zerolog v1.34.0 + golang.org/x/sys v0.39.0 +) + +require ( + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e286965 --- /dev/null +++ b/go.sum @@ -0,0 +1,16 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/udp-bcast-cidr-relay b/udp-bcast-cidr-relay new file mode 100755 index 0000000..e34e834 Binary files /dev/null and b/udp-bcast-cidr-relay differ diff --git a/udp-bcast-cidr-relay.go b/udp-bcast-cidr-relay.go new file mode 100644 index 0000000..3e66291 --- /dev/null +++ b/udp-bcast-cidr-relay.go @@ -0,0 +1,443 @@ +// udp-bcast-cidr-relay.go +// Build (static, no cgo): +// +// CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -trimpath -ldflags='-s -w -extldflags "-static"' -o udp-bcast-cidr-relay ./udp-bcast-cidr-relay.go +package main + +import ( + "encoding/binary" + "errors" + "flag" + "fmt" + "net" + "os" + "os/signal" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "github.com/rs/zerolog" + "golang.org/x/sys/unix" +) + +type subnetEntry struct { + IfName string + IfIdx int + Net *net.IPNet // IPv4 + Bcast net.IP // 4 bytes +} + +type recvMeta struct { + IfIndex int + TTL int +} + +func ip4ToU32(ip net.IP) (uint32, bool) { + ip4 := ip.To4() + if ip4 == nil || len(ip4) != 4 { + return 0, false + } + return binary.BigEndian.Uint32(ip.To4()), true +} + +func localIPv4SetFromSubnets(subnets []subnetEntry) map[uint32]bool { + exists := map[uint32]bool{} + for _, s := range subnets { + if u, ok := ip4ToU32(s.Net.IP); ok { + exists[u] = true + } + } + return exists +} + +func parseCIDRs(csv string) ([]*net.IPNet, error) { + if strings.TrimSpace(csv) == "" { + return nil, fmt.Errorf("cidrs required") + } + var out []*net.IPNet + for _, s := range strings.Split(csv, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("bad cidr %q: %w", s, err) + } + // normalize to IPv4 nets only + ip4 := n.IP.To4() + if ip4 == nil || len(n.Mask) != 4 { + return nil, fmt.Errorf("cidr must be IPv4: %q", s) + } + n.IP = ip4 + out = append(out, n) + } + return out, nil +} + +func parsePorts(csv string) ([]int, error) { + if strings.TrimSpace(csv) == "" { + return nil, fmt.Errorf("ports required") + } + var out []int + for _, s := range strings.Split(csv, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + p, err := strconv.Atoi(s) + if err != nil || p < 1 || p > 65535 { + return nil, fmt.Errorf("bad port %q", s) + } + out = append(out, p) + } + return out, nil +} + +func ipv4Bcast(ip net.IP, mask net.IPMask) (net.IP, bool) { + ip4 := ip.To4() + if ip4 == nil || len(mask) != 4 { + return nil, false + } + b := make(net.IP, 4) + for i := 0; i < 4; i++ { + b[i] = ip4[i] | ^mask[i] + } + return b, true +} + +func netWithinAny(n *net.IPNet, cidrs []*net.IPNet) bool { + // treat "within" as: network address is inside CIDR and broadcast is inside CIDR + b, ok := ipv4Bcast(n.IP, n.Mask) + if !ok { + return false + } + for _, c := range cidrs { + if c.Contains(n.IP) && c.Contains(b) { + return true + } + } + return false +} + +func discoverSubnets(cidrs []*net.IPNet, ifacesCSV string) ([]subnetEntry, error) { + allow := map[string]bool{} + if strings.TrimSpace(ifacesCSV) != "" { + for _, s := range strings.Split(ifacesCSV, ",") { + s = strings.TrimSpace(s) + if s != "" { + allow[s] = true + } + } + } + + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + var subs []subnetEntry + for _, ifc := range ifaces { + if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagLoopback != 0 { + continue + } + if len(allow) > 0 && !allow[ifc.Name] { + continue + } + addrs, err := ifc.Addrs() + if err != nil { + return nil, fmt.Errorf("addrs(%s): %w", ifc.Name, err) + } + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + ip4 := ipnet.IP.To4() + if ip4 == nil || len(ipnet.Mask) != 4 { + continue + } + n := &net.IPNet{IP: ip4, Mask: ipnet.Mask} + if !netWithinAny(n, cidrs) { + continue + } + b, ok := ipv4Bcast(n.IP, n.Mask) + if !ok { + continue + } + subs = append(subs, subnetEntry{ + IfName: ifc.Name, + IfIdx: ifc.Index, + Net: n, + Bcast: b, + }) + } + } + + // de-dupe (same subnet repeated) + seen := map[string]bool{} + var out []subnetEntry + for _, s := range subs { + k := fmt.Sprintf("%d-%s/%d", s.IfIdx, s.Net.IP.String(), maskBits(s.Net.Mask)) + if !seen[k] { + seen[k] = true + out = append(out, s) + } + } + sort.Slice(out, func(i, j int) bool { + if out[i].IfName == out[j].IfName { + return binary.BigEndian.Uint32(out[i].Bcast) < binary.BigEndian.Uint32(out[j].Bcast) + } + return out[i].IfName < out[j].IfName + }) + return out, nil +} + +func maskBits(m net.IPMask) int { + ones, _ := m.Size() + return ones +} + +func parseOOB(oob []byte) recvMeta { + meta := recvMeta{IfIndex: 0, TTL: -1} + msgs, err := unix.ParseSocketControlMessage(oob) + if err != nil { + return meta + } + for _, m := range msgs { + switch { + case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_PKTINFO: + // in_pktinfo: ifindex (u32 LE) + spec_dst + addr + if len(m.Data) >= 12 { + meta.IfIndex = int(binary.LittleEndian.Uint32(m.Data[0:4])) + } + case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_TTL: + switch len(m.Data) { + case 1: + meta.TTL = int(m.Data[0]) + case 4: + ttlLE := int(binary.LittleEndian.Uint32(m.Data)) + ttlBE := int(binary.BigEndian.Uint32(m.Data)) + if ttlLE >= 0 && ttlLE <= 255 { + meta.TTL = ttlLE + } else { + meta.TTL = ttlBE + } + } + } + } + return meta +} + +type sender struct { + ifname string + conn net.PacketConn +} + +func newSender(ifname string, ttl int) (*sender, error) { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + if err != nil { + return nil, err + } + if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil { + _ = unix.Close(fd) + return nil, err + } + if err = unix.SetsockoptString(fd, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); err != nil { + _ = unix.Close(fd) + return nil, err + } + if err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, ttl); err != nil { + _ = unix.Close(fd) + return nil, err + } + + f := os.NewFile(uintptr(fd), "udp-send-"+ifname) + pc, err := net.FilePacketConn(f) + _ = f.Close() + if err != nil { + _ = unix.Close(fd) + return nil, err + } + return &sender{ifname: ifname, conn: pc}, nil +} + +func main() { + cidrsFlag := flag.String("cidrs", "", "comma-separated IPv4 CIDRs to include (for example 192.168.0.0/16,10.0.0.0/8)") + ifacesFlag := flag.String("ifaces", "", "optional comma-separated interface allow-list (empty = all UP non-loopback)") + portsFlag := flag.String("ports", "32410,32412,32413,32414,50222", "comma-separated UDP ports to relay") + ttlFlag := flag.Int("ttl", 1, "outbound TTL for forwarded packets (use 1)") + dropTTLFlag := flag.Int("drop-ttl", 1, "drop inbound packets with this TTL (use 1 to drop your forwarded copies)") + timeoutMS := flag.Int("timeout-ms", 250, "read timeout in ms") + logLevelFlag := flag.String("log-level", "info", "log level: trace|debug|info|warn|error|fatal|panic") + flag.Parse() + + lvl, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(*logLevelFlag))) + if err != nil { + fmt.Fprintf(os.Stderr, "WARN: bad --log-level %q; using info\n", *logLevelFlag) + lvl = zerolog.InfoLevel + } + lg := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(lvl) + + cidrs, err := parseCIDRs(*cidrsFlag) + if err != nil { + lg.Fatal().Msgf("cidrs: %v", err) + } + ports, err := parsePorts(*portsFlag) + if err != nil { + lg.Fatal().Msgf("ports: %v", err) + } + + subnets, err := discoverSubnets(cidrs, *ifacesFlag) + if err != nil { + lg.Fatal().Msgf("discover: %v", err) + } + if len(subnets) < 2 { + lg.Fatal().Msgf("need at least 2 discovered subnets inside -cidrs; got %d", len(subnets)) + } + + localIPs := localIPv4SetFromSubnets(subnets) + + // One sender per interface name + senders := map[string]*sender{} + for _, s := range subnets { + if _, ok := senders[s.IfName]; ok { + continue + } + se, err := newSender(s.IfName, *ttlFlag) + if err != nil { + lg.Fatal().Msgf("sender(%s): %v", s.IfName, err) + } + senders[s.IfName] = se + } + defer func() { + for _, s := range senders { + _ = s.conn.Close() + } + }() + + lg.Info().Msgf("Discovered subnets in %s:", *cidrsFlag) + for _, s := range subnets { + lg.Info().Msgf(" if=%s idx=%d net=%s bcast=%s", s.IfName, s.IfIdx, s.Net.String(), s.Bcast.String()) + } + lg.Info().Msgf("Relaying UDP ports=%v ttl=%d dropTTL=%d", ports, *ttlFlag, *dropTTLFlag) + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + + for _, p := range ports { + port := p + go func() { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + if err != nil { + lg.Fatal().Msgf("socket :%d: %v", port, err) + } + defer unix.Close(fd) + + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { + lg.Fatal().Msgf("reuseaddr :%d: %v", port, err) + } + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil { + lg.Fatal().Msgf("broadcast :%d: %v", port, err) + } + if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_PKTINFO, 1); err != nil { + lg.Fatal().Msgf("pktinfo :%d: %v", port, err) + } + if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1); err != nil { + lg.Fatal().Msgf("recvttl :%d: %v", port, err) + } + + sa := &unix.SockaddrInet4{Port: port} + copy(sa.Addr[:], net.IPv4zero.To4()) + if err := unix.Bind(fd, sa); err != nil { + lg.Fatal().Msgf("bind :%d: %v", port, err) + } + + buf := make([]byte, 65535) + oob := make([]byte, 256) + for { + _ = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{ + Sec: 0, + Usec: int64(*timeoutMS) * 1000, + }) + + n, oobn, _, from, err := unix.Recvmsg(fd, buf, oob, 0) + if err != nil { + // timeout + if errors.Is(err, unix.EAGAIN) || + errors.Is(err, unix.EWOULDBLOCK) || + errors.Is(err, unix.EINTR) { + continue + } + lg.Warn().Msgf("recvmsg :%d: %v (stopping listener goroutine)", port, err) + return + } + meta := parseOOB(oob[:oobn]) + if meta.TTL == *dropTTLFlag || meta.IfIndex == 0 { + lg.Debug().Msgf("drop :%d: meta ttl=%d ifindex=%d (dropTTL=%d)", port, meta.TTL, meta.IfIndex, *dropTTLFlag) + continue + } + + var srcIP net.IP + switch fa := from.(type) { + case *unix.SockaddrInet4: + srcIP = net.IPv4(fa.Addr[0], fa.Addr[1], fa.Addr[2], fa.Addr[3]).To4() + default: + lg.Debug().Msgf("drop :%d: unexpected from addr type %T", port, from) + continue + } + if srcIP == nil { + continue + } + + // if TTL is our forwarded TTL, drop. + if meta.TTL == *dropTTLFlag && *dropTTLFlag >= 0 { + lg.Debug().Msgf("drop :%d: ttl matches dropTTL src=%s ttl=%d", port, srcIP.String(), meta.TTL) + continue + } + + // If the packet's source IP is one of ours, it's a rebroadcast + // we generated. Drop it to avoid looping. + if u, ok := ip4ToU32(srcIP); ok && localIPs[u] { + lg.Debug().Msgf("drop :%d: looped packet src=%s", port, srcIP.String()) + continue + } + + // Determine which discovered subnet this packet originated from (on the ingress ifindex). + origin := map[string]bool{} // key: ifidx+net string + for _, sn := range subnets { + if sn.IfIdx != meta.IfIndex { + continue + } + if sn.Net.Contains(srcIP) { + origin[fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())] = true + } + } + + payload := buf[:n] + + lg.Debug().Msgf("recv :%d: %d bytes from=%s ifindex=%d ttl=%d", port, n, srcIP.String(), meta.IfIndex, meta.TTL) + + // Fan out to every broadcast target except the origin subnet(s). + for _, sn := range subnets { + k := fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String()) + if origin[k] { + continue + } + s := senders[sn.IfName] + if s == nil { + continue + } + dst := &net.UDPAddr{IP: sn.Bcast, Port: port} + _, _ = s.conn.WriteTo(payload, dst) + lg.Debug().Msgf("sent :%d: %d bytes to=%s via=%s", port, len(payload), dst.String(), sn.IfName) + } + } + }() + } + + <-stop + lg.Info().Msg("Shutting down") + time.Sleep(100 * time.Millisecond) +}