package server import ( "net" "os" "os/signal" "strings" "git.giaever.org/joachimmg/go-log.git/log" "git.giaever.org/joachimmg/m-dns/config" "git.giaever.org/joachimmg/m-dns/connection" "git.giaever.org/joachimmg/m-dns/errors" "git.giaever.org/joachimmg/m-dns/host" "git.giaever.org/joachimmg/m-dns/zone" "github.com/miekg/dns" ) type Server interface { Daemon() Close() error } type server struct { zone.Zone *net.Interface ipv4 connection.UDP ipv6 connection.UDP running bool runCh chan struct{} } func New(z zone.Zone, iface *net.Interface) (Server, error) { s := new(server) if s == nil { log.Traceln(errors.Server, errors.OutOfMemory) return nil, errors.OutOfMemory } if z == nil { log.Traceln(errors.Server, errors.ServerIsNil) return nil, errors.ServerIsNil } s.Zone = z s.Interface = iface s.ipv4 = connection.New(4) s.ipv6 = connection.New(6) if err := s.ipv4.ListenMulticast(s.Interface, config.MdnsIPv4Addr); err != nil { log.Traceln(errors.Server, config.MdnsIPv4, err) } if err := s.ipv6.ListenMulticast(s.Interface, config.MdnsIPv6Addr); err != nil { log.Traceln(errors.Server, config.MdnsIPv6, err) } if !s.ipv4.Listening() && s.ipv6.Listening() { log.Traceln(errors.Server, errors.ServerNoListenersStarted) return nil, errors.ServerNoListenersStarted } s.runCh = make(chan struct{}) s.running = true go s.recv(s.ipv4) go s.recv(s.ipv6) return s, nil } func (s *server) shutdownListener() { log.Traceln("Shutdown listener set on ctrl+x") c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { for range c { s.Close() } }() } func (s *server) Close() error { log.Traceln(errors.Server, "Closing") if !s.running { return nil } if err := s.ipv4.Close(); err != nil { return err } if err := s.ipv6.Close(); err != nil { return err } s.running = false close(s.runCh) return nil } func (s *server) Daemon() { log.Traceln("Daemon running.") go s.shutdownListener() <-s.runCh log.Traceln("Daemon ending.") } func (s *server) recv(c connection.UDP) { if c == nil { return } buf := make([]byte, config.BufSize) for s.running { _, addr, err := c.Read(buf) if err != nil { if !strings.Contains(err.Error(), "i/o timeout") { log.Traceln(errors.Server, err) } continue } go s.handlePacket(buf, addr) } } func (s *server) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) { if msg == nil { return 0, nil } buf, err := msg.Pack() if err != nil { return 0, err } _, t := host.IP(addr.IP).Type() if !s.running { return 0, nil } switch t { case host.IPv4: return s.ipv4.Write(buf, addr) case host.IPv6: return s.ipv6.Write(buf, addr) default: log.Traceln(errors.Server, addr, t, errors.ServerUnknownConnectionAddr) return 0, errors.ServerUnknownConnectionAddr } } func (s *server) handlePacket(p []byte, addr *net.UDPAddr) { msg := new(dns.Msg) if err := msg.Unpack(p); err != nil { log.Warningln(errors.Server, addr, err) } umsg, mmsg, err := s.handleMsg(msg) if err != nil { log.Warningln(errors.Server, addr, err) return } if n, err := s.send(umsg, addr); err != nil { log.Warningln(errors.Server, "Wrote", n, err) } if n, err := s.send(mmsg, addr); err != nil { log.Warningln(errors.Server, "Wrote", n, err) } } func (s *server) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) { if msg.Opcode != dns.OpcodeQuery { log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode) return nil, nil, errors.ServerReceivedNonQueryOpcode } if msg.Rcode != 0 { log.Traceln(errors.Server, errors.ServerReceivedNonZeroRcode) return nil, nil, errors.ServerReceivedNonZeroRcode } if msg.Truncated { log.Traceln(errors.Server, errors.ServerReceivedTruncatedSet) return nil, nil, errors.ServerReceivedTruncatedSet } var uAnswer, mAnswer []dns.RR for _, q := range msg.Question { uRecords, mRecords := s.handleQuestion(q) uAnswer = append(uAnswer, uRecords...) mAnswer = append(mAnswer, mRecords...) } return s.handleResponse(msg, true, uAnswer), s.handleResponse(msg, false, mAnswer), nil } func (s *server) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) { r := s.Records(question) if len(r) == 0 { return nil, nil } for i, rec := range r { log.Traceln(errors.Server, "Record", i, rec) } if question.Qclass&(1<<15) != 0 || config.ForceUnicast { return r, nil } return nil, r } func (s *server) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg { id := uint16(0) if uni { id = msg.Id } if len(ans) == 0 { return nil } return &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: id, Response: true, Opcode: dns.OpcodeQuery, Authoritative: true, }, Compress: true, Answer: ans, } }