444 lines
11 KiB
Go
444 lines
11 KiB
Go
// udp-bcast-cidr-relay.go
|
|
// Build (static, no cgo):
|
|
//
|
|
// CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -trimpath -ldflags='-s -w -extldflags "-static"' -o udp-bcast-cidr-relay ./udp-bcast-cidr-relay.go
|
|
package main
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
type subnetEntry struct {
|
|
IfName string
|
|
IfIdx int
|
|
Net *net.IPNet // IPv4
|
|
Bcast net.IP // 4 bytes
|
|
}
|
|
|
|
// recvMeta carries the per-datagram metadata decoded from recvmsg's
// ancillary data by parseOOB.
type recvMeta struct {
	IfIndex int // ingress interface index; 0 when not reported
	TTL     int // IP TTL of the received packet; -1 when not reported
}
|
|
|
|
// ip4ToU32 converts an IPv4 address, given in either 4-byte or
// IPv4-in-IPv6 16-byte form, to its big-endian uint32 value. The boolean
// is false (and the value 0) for nil or non-IPv4 input.
func ip4ToU32(ip net.IP) (uint32, bool) {
	v4 := ip.To4()
	// To4 returns nil (length 0) for anything that is not IPv4.
	if len(v4) != net.IPv4len {
		return 0, false
	}
	return binary.BigEndian.Uint32(v4), true
}
|
|
|
|
func localIPv4SetFromSubnets(subnets []subnetEntry) map[uint32]bool {
|
|
exists := map[uint32]bool{}
|
|
for _, s := range subnets {
|
|
if u, ok := ip4ToU32(s.Net.IP); ok {
|
|
exists[u] = true
|
|
}
|
|
}
|
|
return exists
|
|
}
|
|
|
|
// parseCIDRs parses a comma-separated list of IPv4 CIDRs such as
// "192.168.0.0/16,10.0.0.0/8". Blank entries are skipped. Each result has a
// 4-byte IP (net.ParseCIDR already reduces it to the network address) and a
// 4-byte mask.
//
// It returns an error when csv is blank or consists only of separators,
// when an entry does not parse, or when an entry is not IPv4.
func parseCIDRs(csv string) ([]*net.IPNet, error) {
	if strings.TrimSpace(csv) == "" {
		return nil, fmt.Errorf("cidrs required")
	}
	var out []*net.IPNet
	for _, s := range strings.Split(csv, ",") {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		_, n, err := net.ParseCIDR(s)
		if err != nil {
			return nil, fmt.Errorf("bad cidr %q: %w", s, err)
		}
		// normalize to IPv4 nets only
		ip4 := n.IP.To4()
		if ip4 == nil || len(n.Mask) != net.IPv4len {
			return nil, fmt.Errorf("cidr must be IPv4: %q", s)
		}
		n.IP = ip4
		out = append(out, n)
	}
	// A value made only of commas/whitespace would otherwise yield no CIDRs
	// and no error.
	if len(out) == 0 {
		return nil, fmt.Errorf("cidrs required")
	}
	return out, nil
}
|
|
|
|
// parsePorts parses a comma-separated list of UDP port numbers.
//
// Blank entries are skipped. Repeated ports are dropped (first occurrence
// wins): the caller opens one SO_REUSEADDR listener per returned port, and
// two listeners on the same port would each receive every broadcast and
// forward it twice.
//
// It returns an error when csv is blank or consists only of separators, or
// when any entry is not an integer in 1..65535.
func parsePorts(csv string) ([]int, error) {
	if strings.TrimSpace(csv) == "" {
		return nil, fmt.Errorf("ports required")
	}
	var out []int
	seen := make(map[int]bool)
	for _, field := range strings.Split(csv, ",") {
		field = strings.TrimSpace(field)
		if field == "" {
			continue
		}
		p, err := strconv.Atoi(field)
		if err != nil || p < 1 || p > 65535 {
			return nil, fmt.Errorf("bad port %q", field)
		}
		if seen[p] {
			continue
		}
		seen[p] = true
		out = append(out, p)
	}
	// A value made only of commas/whitespace would otherwise yield no ports
	// and no error, leaving the relay running with no listeners.
	if len(out) == 0 {
		return nil, fmt.Errorf("ports required")
	}
	return out, nil
}
|
|
|
|
// ipv4Bcast returns the directed-broadcast address of the subnet that ip
// belongs to under mask, i.e. ip with every host bit set, as a fresh 4-byte
// net.IP. It reports false when ip is not IPv4 or mask is not 4 bytes long.
func ipv4Bcast(ip net.IP, mask net.IPMask) (net.IP, bool) {
	v4 := ip.To4()
	if v4 == nil || len(mask) != net.IPv4len {
		return nil, false
	}
	addr := binary.BigEndian.Uint32(v4)
	bits := binary.BigEndian.Uint32(mask)
	bcast := make(net.IP, net.IPv4len)
	binary.BigEndian.PutUint32(bcast, addr|^bits)
	return bcast, true
}
|
|
|
|
// netWithinAny reports whether the IPv4 subnet n lies entirely inside at
// least one of cidrs.
//
// n normally comes from an interface address, so n.IP keeps its host bits
// (e.g. 10.255.255.1/8). Probing n.IP and the broadcast address is
// therefore not enough: both 10.255.255.1 and 10.255.255.255 fall inside
// 10.255.255.0/24 although the /8 does not. Because two CIDR blocks are
// either nested or disjoint, n is inside c exactly when c's prefix is no
// longer than n's and c contains some address of n.
//
// Returns false when n is not IPv4, when its mask is not a canonical 4-byte
// mask, or when no CIDR qualifies.
func netWithinAny(n *net.IPNet, cidrs []*net.IPNet) bool {
	ip4 := n.IP.To4()
	if ip4 == nil || len(n.Mask) != net.IPv4len {
		return false
	}
	nOnes, nBits := n.Mask.Size()
	if nBits != 32 {
		return false // non-canonical mask
	}
	for _, c := range cidrs {
		cOnes, cBits := c.Mask.Size()
		if cBits == 32 && cOnes <= nOnes && c.Contains(ip4) {
			return true
		}
	}
	return false
}
|
|
|
|
// discoverSubnets enumerates the IPv4 addresses configured on every UP,
// non-loopback interface and keeps those whose subnet lies inside one of
// cidrs (per netWithinAny). ifacesCSV, when it names at least one
// interface, restricts the search to those interface names.
//
// Each entry carries the interface name and index, the interface address
// with its prefix (Net.IP is the host address as reported by Addrs), and
// the subnet's broadcast address. Entries repeating the same interface
// index, address and prefix length are collapsed. The result is ordered by
// interface name, then by broadcast address.
func discoverSubnets(cidrs []*net.IPNet, ifacesCSV string) ([]subnetEntry, error) {
	allowed := make(map[string]bool)
	for _, name := range strings.Split(ifacesCSV, ",") {
		if name = strings.TrimSpace(name); name != "" {
			allowed[name] = true
		}
	}

	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}

	var found []subnetEntry
	for _, ifc := range ifaces {
		isUp := ifc.Flags&net.FlagUp != 0
		isLoopback := ifc.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue
		}
		if len(allowed) != 0 && !allowed[ifc.Name] {
			continue
		}

		addrs, err := ifc.Addrs()
		if err != nil {
			return nil, fmt.Errorf("addrs(%s): %w", ifc.Name, err)
		}
		for _, addr := range addrs {
			ipn, isNet := addr.(*net.IPNet)
			if !isNet || ipn.IP == nil {
				continue
			}
			v4 := ipn.IP.To4()
			if v4 == nil || len(ipn.Mask) != 4 {
				continue // not an IPv4 address with an IPv4 mask
			}
			sub := &net.IPNet{IP: v4, Mask: ipn.Mask}
			if !netWithinAny(sub, cidrs) {
				continue
			}
			bcast, ok := ipv4Bcast(sub.IP, sub.Mask)
			if !ok {
				continue
			}
			found = append(found, subnetEntry{
				IfName: ifc.Name,
				IfIdx:  ifc.Index,
				Net:    sub,
				Bcast:  bcast,
			})
		}
	}

	// Collapse repeats of the same (interface index, address, prefix length).
	var out []subnetEntry
	seen := make(map[string]bool, len(found))
	for _, e := range found {
		key := fmt.Sprintf("%d-%s/%d", e.IfIdx, e.Net.IP.String(), maskBits(e.Net.Mask))
		if seen[key] {
			continue
		}
		seen[key] = true
		out = append(out, e)
	}

	sort.Slice(out, func(i, j int) bool {
		a, b := out[i], out[j]
		if a.IfName != b.IfName {
			return a.IfName < b.IfName
		}
		return binary.BigEndian.Uint32(a.Bcast) < binary.BigEndian.Uint32(b.Bcast)
	})
	return out, nil
}
|
|
|
|
// maskBits returns the prefix length of m. net.IPMask.Size reports 0 for a
// non-canonical mask, so such masks yield 0 as well.
func maskBits(m net.IPMask) int {
	prefixLen, _ := m.Size()
	return prefixLen
}
|
|
|
|
// parseOOB decodes the ancillary data returned by recvmsg on a listener
// socket (IP_PKTINFO and IP_RECVTTL enabled) into a recvMeta.
//
// IfIndex comes from an IP_PKTINFO message and TTL from an IP_TTL message.
// A field whose message is missing or too short keeps its default
// (IfIndex 0, TTL -1). A control buffer that fails to parse yields the
// defaults for both.
func parseOOB(oob []byte) recvMeta {
	meta := recvMeta{IfIndex: 0, TTL: -1}
	msgs, err := unix.ParseSocketControlMessage(oob)
	if err != nil {
		return meta
	}
	for _, m := range msgs {
		if m.Header.Level != unix.IPPROTO_IP {
			continue
		}
		switch m.Header.Type {
		case unix.IP_PKTINFO:
			// struct in_pktinfo { int ipi_ifindex; struct in_addr ipi_spec_dst;
			// struct in_addr ipi_addr; } is 12 bytes. ipi_ifindex is a C int
			// in the CPU's native byte order. A fixed little-endian decode
			// byte-swaps it on big-endian targets such as the mips64 build in
			// the file header, so the origin subnet would never be matched.
			// binary.NativeEndian requires Go 1.21+.
			if len(m.Data) >= 12 {
				meta.IfIndex = int(binary.NativeEndian.Uint32(m.Data[0:4]))
			}
		case unix.IP_TTL:
			// The TTL is delivered as a native-endian int. A 1-byte form is
			// accepted as well.
			switch len(m.Data) {
			case 1:
				meta.TTL = int(m.Data[0])
			case 4:
				meta.TTL = int(binary.NativeEndian.Uint32(m.Data))
			}
		}
	}
	return meta
}
|
|
|
|
type sender struct {
|
|
ifname string
|
|
conn net.PacketConn
|
|
}
|
|
|
|
// newSender opens a UDP/IPv4 socket for transmitting relayed broadcasts out
// of interface ifname and wraps it in a net.PacketConn.
//
// The socket has SO_BROADCAST enabled (needed to send to a broadcast
// address), is pinned to ifname with SO_BINDTODEVICE, and sends with IP TTL
// ttl. Each error is wrapped with the step that failed.
func newSender(ifname string, ttl int) (*sender, error) {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err != nil {
		return nil, fmt.Errorf("socket: %w", err)
	}
	if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("SO_BROADCAST: %w", err)
	}
	if err := unix.SetsockoptString(fd, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifname); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("SO_BINDTODEVICE: %w", err)
	}
	if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_TTL, ttl); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("IP_TTL %d: %w", ttl, err)
	}

	// os.NewFile takes ownership of fd, and net.FilePacketConn returns a
	// connection on its own copy. Closing f therefore releases fd on every
	// path. fd must not be closed again afterwards: its number may already
	// belong to another descriptor.
	f := os.NewFile(uintptr(fd), "udp-send-"+ifname)
	pc, err := net.FilePacketConn(f)
	_ = f.Close()
	if err != nil {
		return nil, fmt.Errorf("FilePacketConn: %w", err)
	}
	return &sender{ifname: ifname, conn: pc}, nil
}
|
|
|
|
func main() {
|
|
cidrsFlag := flag.String("cidrs", "", "comma-separated IPv4 CIDRs to include (for example 192.168.0.0/16,10.0.0.0/8)")
|
|
ifacesFlag := flag.String("ifaces", "", "optional comma-separated interface allow-list (empty = all UP non-loopback)")
|
|
portsFlag := flag.String("ports", "32410,32412,32413,32414,50222", "comma-separated UDP ports to relay")
|
|
ttlFlag := flag.Int("ttl", 1, "outbound TTL for forwarded packets (use 1)")
|
|
dropTTLFlag := flag.Int("drop-ttl", 1, "drop inbound packets with this TTL (use 1 to drop your forwarded copies)")
|
|
timeoutMS := flag.Int("timeout-ms", 250, "read timeout in ms")
|
|
logLevelFlag := flag.String("log-level", "info", "log level: trace|debug|info|warn|error|fatal|panic")
|
|
flag.Parse()
|
|
|
|
lvl, err := zerolog.ParseLevel(strings.ToLower(strings.TrimSpace(*logLevelFlag)))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "WARN: bad --log-level %q; using info\n", *logLevelFlag)
|
|
lvl = zerolog.InfoLevel
|
|
}
|
|
lg := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(lvl)
|
|
|
|
cidrs, err := parseCIDRs(*cidrsFlag)
|
|
if err != nil {
|
|
lg.Fatal().Msgf("cidrs: %v", err)
|
|
}
|
|
ports, err := parsePorts(*portsFlag)
|
|
if err != nil {
|
|
lg.Fatal().Msgf("ports: %v", err)
|
|
}
|
|
|
|
subnets, err := discoverSubnets(cidrs, *ifacesFlag)
|
|
if err != nil {
|
|
lg.Fatal().Msgf("discover: %v", err)
|
|
}
|
|
if len(subnets) < 2 {
|
|
lg.Fatal().Msgf("need at least 2 discovered subnets inside -cidrs; got %d", len(subnets))
|
|
}
|
|
|
|
localIPs := localIPv4SetFromSubnets(subnets)
|
|
|
|
// One sender per interface name
|
|
senders := map[string]*sender{}
|
|
for _, s := range subnets {
|
|
if _, ok := senders[s.IfName]; ok {
|
|
continue
|
|
}
|
|
se, err := newSender(s.IfName, *ttlFlag)
|
|
if err != nil {
|
|
lg.Fatal().Msgf("sender(%s): %v", s.IfName, err)
|
|
}
|
|
senders[s.IfName] = se
|
|
}
|
|
defer func() {
|
|
for _, s := range senders {
|
|
_ = s.conn.Close()
|
|
}
|
|
}()
|
|
|
|
lg.Info().Msgf("Discovered subnets in %s:", *cidrsFlag)
|
|
for _, s := range subnets {
|
|
lg.Info().Msgf(" if=%s idx=%d net=%s bcast=%s", s.IfName, s.IfIdx, s.Net.String(), s.Bcast.String())
|
|
}
|
|
lg.Info().Msgf("Relaying UDP ports=%v ttl=%d dropTTL=%d", ports, *ttlFlag, *dropTTLFlag)
|
|
|
|
stop := make(chan os.Signal, 1)
|
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
|
|
|
for _, p := range ports {
|
|
port := p
|
|
go func() {
|
|
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
|
if err != nil {
|
|
lg.Fatal().Msgf("socket :%d: %v", port, err)
|
|
}
|
|
defer unix.Close(fd)
|
|
|
|
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
|
|
lg.Fatal().Msgf("reuseaddr :%d: %v", port, err)
|
|
}
|
|
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
|
|
lg.Fatal().Msgf("broadcast :%d: %v", port, err)
|
|
}
|
|
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_PKTINFO, 1); err != nil {
|
|
lg.Fatal().Msgf("pktinfo :%d: %v", port, err)
|
|
}
|
|
if err := unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1); err != nil {
|
|
lg.Fatal().Msgf("recvttl :%d: %v", port, err)
|
|
}
|
|
|
|
sa := &unix.SockaddrInet4{Port: port}
|
|
copy(sa.Addr[:], net.IPv4zero.To4())
|
|
if err := unix.Bind(fd, sa); err != nil {
|
|
lg.Fatal().Msgf("bind :%d: %v", port, err)
|
|
}
|
|
|
|
buf := make([]byte, 65535)
|
|
oob := make([]byte, 256)
|
|
for {
|
|
_ = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{
|
|
Sec: 0,
|
|
Usec: int64(*timeoutMS) * 1000,
|
|
})
|
|
|
|
n, oobn, _, from, err := unix.Recvmsg(fd, buf, oob, 0)
|
|
if err != nil {
|
|
// timeout
|
|
if errors.Is(err, unix.EAGAIN) ||
|
|
errors.Is(err, unix.EWOULDBLOCK) ||
|
|
errors.Is(err, unix.EINTR) {
|
|
continue
|
|
}
|
|
lg.Warn().Msgf("recvmsg :%d: %v (stopping listener goroutine)", port, err)
|
|
return
|
|
}
|
|
meta := parseOOB(oob[:oobn])
|
|
if meta.TTL == *dropTTLFlag || meta.IfIndex == 0 {
|
|
lg.Debug().Msgf("drop :%d: meta ttl=%d ifindex=%d (dropTTL=%d)", port, meta.TTL, meta.IfIndex, *dropTTLFlag)
|
|
continue
|
|
}
|
|
|
|
var srcIP net.IP
|
|
switch fa := from.(type) {
|
|
case *unix.SockaddrInet4:
|
|
srcIP = net.IPv4(fa.Addr[0], fa.Addr[1], fa.Addr[2], fa.Addr[3]).To4()
|
|
default:
|
|
lg.Debug().Msgf("drop :%d: unexpected from addr type %T", port, from)
|
|
continue
|
|
}
|
|
if srcIP == nil {
|
|
continue
|
|
}
|
|
|
|
// if TTL is our forwarded TTL, drop.
|
|
if meta.TTL == *dropTTLFlag && *dropTTLFlag >= 0 {
|
|
lg.Debug().Msgf("drop :%d: ttl matches dropTTL src=%s ttl=%d", port, srcIP.String(), meta.TTL)
|
|
continue
|
|
}
|
|
|
|
// If the packet's source IP is one of ours, it's a rebroadcast
|
|
// we generated. Drop it to avoid looping.
|
|
if u, ok := ip4ToU32(srcIP); ok && localIPs[u] {
|
|
lg.Debug().Msgf("drop :%d: looped packet src=%s", port, srcIP.String())
|
|
continue
|
|
}
|
|
|
|
// Determine which discovered subnet this packet originated from (on the ingress ifindex).
|
|
origin := map[string]bool{} // key: ifidx+net string
|
|
for _, sn := range subnets {
|
|
if sn.IfIdx != meta.IfIndex {
|
|
continue
|
|
}
|
|
if sn.Net.Contains(srcIP) {
|
|
origin[fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())] = true
|
|
}
|
|
}
|
|
|
|
payload := buf[:n]
|
|
|
|
lg.Debug().Msgf("recv :%d: %d bytes from=%s ifindex=%d ttl=%d", port, n, srcIP.String(), meta.IfIndex, meta.TTL)
|
|
|
|
// Fan out to every broadcast target except the origin subnet(s).
|
|
for _, sn := range subnets {
|
|
k := fmt.Sprintf("%d-%s", sn.IfIdx, sn.Net.String())
|
|
if origin[k] {
|
|
continue
|
|
}
|
|
s := senders[sn.IfName]
|
|
if s == nil {
|
|
continue
|
|
}
|
|
dst := &net.UDPAddr{IP: sn.Bcast, Port: port}
|
|
_, _ = s.conn.WriteTo(payload, dst)
|
|
lg.Debug().Msgf("sent :%d: %d bytes to=%s via=%s", port, len(payload), dst.String(), sn.IfName)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
<-stop
|
|
lg.Info().Msg("Shutting down")
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|