// udp-bcast-cidr-relay.go // Build (static, no cgo): // // CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -trimpath -ldflags='-s -w -extldflags "-static"' -o udp-bcast-cidr-relay ./udp-bcast-cidr-relay.go package main import ( "encoding/binary" "errors" "flag" "fmt" "net" "os" "os/signal" "sort" "strconv" "strings" "syscall" "time" "github.com/rs/zerolog" "golang.org/x/sys/unix" ) type subnetEntry struct { IfName string IfIdx int Net *net.IPNet // IPv4 Bcast net.IP // 4 bytes } type recvMeta struct { IfIndex int TTL int } func ip4ToU32(ip net.IP) (uint32, bool) { ip4 := ip.To4() if ip4 == nil || len(ip4) != 4 { return 0, false } return binary.BigEndian.Uint32(ip.To4()), true } func localIPv4SetFromSubnets(subnets []subnetEntry) map[uint32]bool { exists := map[uint32]bool{} for _, s := range subnets { if u, ok := ip4ToU32(s.Net.IP); ok { exists[u] = true } } return exists } func parseCIDRs(csv string) ([]*net.IPNet, error) { if strings.TrimSpace(csv) == "" { return nil, fmt.Errorf("cidrs required") } var out []*net.IPNet for _, s := range strings.Split(csv, ",") { s = strings.TrimSpace(s) if s == "" { continue } _, n, err := net.ParseCIDR(s) if err != nil { return nil, fmt.Errorf("bad cidr %q: %w", s, err) } // normalize to IPv4 nets only ip4 := n.IP.To4() if ip4 == nil || len(n.Mask) != 4 { return nil, fmt.Errorf("cidr must be IPv4: %q", s) } n.IP = ip4 out = append(out, n) } return out, nil } func parsePorts(csv string) ([]int, error) { if strings.TrimSpace(csv) == "" { return nil, fmt.Errorf("ports required") } var out []int for _, s := range strings.Split(csv, ",") { s = strings.TrimSpace(s) if s == "" { continue } p, err := strconv.Atoi(s) if err != nil || p < 1 || p > 65535 { return nil, fmt.Errorf("bad port %q", s) } out = append(out, p) } return out, nil } func ipv4Bcast(ip net.IP, mask net.IPMask) (net.IP, bool) { ip4 := ip.To4() if ip4 == nil || len(mask) != 4 { return nil, false } b := make(net.IP, 4) for i := 0; i < 4; i++ { b[i] = ip4[i] | ^mask[i] } return b, true } func netWithinAny(n *net.IPNet, cidrs []*net.IPNet) bool { // treat "within" as: network address is inside CIDR and broadcast is inside CIDR b, ok := ipv4Bcast(n.IP, n.Mask) if !ok { return false } for _, c := range cidrs { if c.Contains(n.IP) && c.Contains(b) { return true } } return false } func discoverSubnets(cidrs []*net.IPNet, ifacesCSV string) ([]subnetEntry, error) { allow := map[string]bool{} if strings.TrimSpace(ifacesCSV) != "" { for _, s := range strings.Split(ifacesCSV, ",") { s = strings.TrimSpace(s) if s != "" { allow[s] = true } } } ifaces, err := net.Interfaces() if err != nil { return nil, err } var subs []subnetEntry for _, ifc := range ifaces { if ifc.Flags&net.FlagUp == 0 || ifc.Flags&net.FlagLoopback != 0 { continue } if len(allow) > 0 && !allow[ifc.Name] { continue } addrs, err := ifc.Addrs() if err != nil { return nil, fmt.Errorf("addrs(%s): %w", ifc.Name, err) } for _, a := range addrs { ipnet, ok := a.(*net.IPNet) if !ok || ipnet.IP == nil { continue } ip4 := ipnet.IP.To4() if ip4 == nil || len(ipnet.Mask) != 4 { continue } n := &net.IPNet{IP: ip4, Mask: ipnet.Mask} if !netWithinAny(n, cidrs) { continue } b, ok := ipv4Bcast(n.IP, n.Mask) if !ok { continue } subs = append(subs, subnetEntry{ IfName: ifc.Name, IfIdx: ifc.Index, Net: n, Bcast: b, }) } } // de-dupe (same subnet repeated) seen := map[string]bool{} var out []subnetEntry for _, s := range subs { k := fmt.Sprintf("%d-%s/%d", s.IfIdx, s.Net.IP.String(), maskBits(s.Net.Mask)) if !seen[k] { seen[k] = true out = append(out, s) } } sort.Slice(out, func(i, j int) bool { if out[i].IfName == out[j].IfName { return binary.BigEndian.Uint32(out[i].Bcast) < binary.BigEndian.Uint32(out[j].Bcast) } return out[i].IfName < out[j].IfName }) return out, nil } func maskBits(m net.IPMask) int { ones, _ := m.Size() return ones } func parseOOB(oob []byte) recvMeta { meta := recvMeta{IfIndex: 0, TTL: -1} msgs, err := unix.ParseSocketControlMessage(oob) if err != nil { return meta } for _, m := range msgs { switch { case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_PKTINFO: // in_pktinfo: ifindex (u32 LE) + spec_dst + addr if len(m.Data) >= 12 { meta.IfIndex = int(binary.LittleEndian.Uint32(m.Data[0:4])) } case m.Header.Level == unix.IPPROTO_IP && m.Header.Type == unix.IP_TTL: switch len(m.Data) { case 1: meta.TTL = int(m.Data[0]) case 4: ttlLE := int(binary.LittleEndian.Uint32(m.Data)) ttlBE := int(binary.BigEndian.Uint32(m.Data)) if ttlLE >= 0 && ttlLE <= 255 { meta.TTL = ttlLE } else { meta.TTL = ttlBE } } } } return meta } type sender struct { ifname string conn net.PacketConn } func newSender(ifname string, ttl int) (*sender, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err != nil { return nil, err } if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil { _ = unix.Close(fd) return nil, err } if err = unix.SetsockoptString(fd, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); err != nil { _ = unix.Close(fd) return nil, err } if err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, ttl); err != nil { _ = unix.Close(fd) return nil, err } f := os.NewFile(uintptr(fd), "udp-send-"+ifname) pc, err := net.FilePacketConn(f) _ = f.Close() if err != nil { _ = unix.Close(fd) return nil, err } return &sender{ifname: ifname, conn: pc}, nil } func main() { cidrsFlag := flag.String("cidrs", "", "comma-separated IPv4 CIDRs to include (for example 192.168.0.0/16,10.0.0.0/8)") ifacesFlag := flag.String("ifaces", "", "optional comma-separated interface allow-list (empty = all UP non-loopback)") portsFlag := flag.String("ports", "32410,32412,32413,32414,50222", "comma-separated UDP ports to relay") ttlFlag := flag.Int("ttl", 1, "outbound TTL for forwarded packets (use 1)") dropTTLFlag := flag.Int("drop-ttl", 1, "drop inbound packets with this TTL (use 1 to drop your forwarded copies)") timeoutMS := flag.Int("timeout-ms", 250, "read timeout in ms") logLevelFlag := flag.String("log-level", "info", "log level: trace|debug|info|warn|error|fatal|panic") flag.Parse() lvl, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(*logLevelFlag))) if err != nil { fmt.Fprintf(os.Stderr, "WARN: bad --log-level %q; using info\n", *logLevelFlag) lvl = zerolog.InfoLevel } lg := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(lvl) cidrs, err := parseCIDRs(*cidrsFlag) if err != nil { lg.Fatal().Msgf("cidrs: %v", err) } ports, err := parsePorts(*portsFlag) if err != nil { lg.Fatal().Msgf("ports: %v", err) } subnets, err := discoverSubnets(cidrs, *ifacesFlag) if err != nil { lg.Fatal().Msgf("discover: %v", err) } if len(subnets) < 2 { lg.Fatal().Msgf("need at least 2 discovered subnets inside -cidrs; got %d", len(subnets)) } localIPs := localIPv4SetFromSubnets(subnets) // One sender per interface name senders := map[string]*sender{} for _, s := range subnets { if _, ok := senders[s.IfName]; ok { continue } se, err := newSender(s.IfName, *ttlFlag) if err != nil { lg.Fatal().Msgf("sender(%s): %v", s.IfName, err) } senders[s.IfName] = se } defer func() { for _, s := range senders { _ = s.conn.Close() } }() lg.Info().Msgf("Discovered subnets in %s:", *cidrsFlag) for _, s := range subnets { lg.Info().Msgf(" if=%s idx=%d net=%s bcast=%s", s.IfName, s.IfIdx, s.Net.String(), s.Bcast.String()) } lg.Info().Msgf("Relaying UDP ports=%v ttl=%d dropTTL=%d", ports, *ttlFlag, *dropTTLFlag) stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) for _, p := range ports { port := p go func() { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err != nil { lg.Fatal().Msgf("socket :%d: %v", port, err) } defer unix.Close(fd) if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { lg.Fatal().Msgf("reuseaddr :%d: %v", port, err) } if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil { lg.Fatal().Msgf("broadcast :%d: %v", port, err) } if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_PKTINFO, 1); err != nil { lg.Fatal().Msgf("pktinfo :%d: %v", port, err) } if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1); err != nil { lg.Fatal().Msgf("recvttl :%d: %v", port, err) } sa := &unix.SockaddrInet4{Port: port} copy(sa.Addr[:], net.IPv4zero.To4()) if err := unix.Bind(fd, sa); err != nil { lg.Fatal().Msgf("bind :%d: %v", port, err) } buf := make([]byte, 65535) oob := make([]byte, 256) for { _ = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{ Sec: 0, Usec: int64(*timeoutMS) * 1000, }) n, oobn, _, from, err := unix.Recvmsg(fd, buf, oob, 0) if err != nil { // timeout if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, unix.EINTR) { continue } lg.Warn().Msgf("recvmsg :%d: %v (stopping listener goroutine)", port, err) return } meta := parseOOB(oob[:oobn]) if meta.TTL == *dropTTLFlag || meta.IfIndex == 0 { lg.Debug().Msgf("drop :%d: meta ttl=%d ifindex=%d (dropTTL=%d)", port, meta.TTL, meta.IfIndex, *dropTTLFlag) continue } var srcIP net.IP switch fa := from.(type) { case *unix.SockaddrInet4: srcIP = net.IPv4(fa.Addr[0], fa.Addr[1], fa.Addr[2], fa.Addr[3]).To4() default: lg.Debug().Msgf("drop :%d: unexpected from addr type %T", port, from) continue } if srcIP == nil { continue } // if TTL is our forwarded TTL, drop. if meta.TTL == *dropTTLFlag && *dropTTLFlag >= 0 { lg.Debug().Msgf("drop :%d: ttl matches dropTTL src=%s ttl=%d", port, srcIP.String(), meta.TTL) continue } // If the packet's source IP is one of ours, it's a rebroadcast // we generated. Drop it to avoid looping. if u, ok := ip4ToU32(srcIP); ok && localIPs[u] { lg.Debug().Msgf("drop :%d: looped packet src=%s", port, srcIP.String()) continue } // Determine which discovered subnet this packet originated from (on the ingress ifindex). origin := map[string]bool{} // key: ifidx+net string for _, sn := range subnets { if sn.IfIdx != meta.IfIndex { continue } if sn.Net.Contains(srcIP) { origin[fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())] = true } } payload := buf[:n] lg.Debug().Msgf("recv :%d: %d bytes from=%s ifindex=%d ttl=%d", port, n, srcIP.String(), meta.IfIndex, meta.TTL) // Fan out to every broadcast target except the origin subnet(s). for _, sn := range subnets { k := fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String()) if origin[k] { continue } s := senders[sn.IfName] if s == nil { continue } dst := &net.UDPAddr{IP: sn.Bcast, Port: port} _, _ = s.conn.WriteTo(payload, dst) lg.Debug().Msgf("sent :%d: %d bytes to=%s via=%s", port, len(payload), dst.String(), sn.IfName) } } }() } <-stop lg.Info().Msg("Shutting down") time.Sleep(100 * time.Millisecond) }