|
@@ -5,32 +5,36 @@ import (
|
|
|
"os"
|
|
|
"os/signal"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
- "time"
|
|
|
|
|
|
"git.giaever.org/joachimmg/go-log.git/log"
|
|
|
"git.giaever.org/joachimmg/m-dns/config"
|
|
|
+ "git.giaever.org/joachimmg/m-dns/connection"
|
|
|
"git.giaever.org/joachimmg/m-dns/errors"
|
|
|
+ "git.giaever.org/joachimmg/m-dns/host"
|
|
|
"git.giaever.org/joachimmg/m-dns/zone"
|
|
|
"github.com/miekg/dns"
|
|
|
)
|
|
|
|
|
|
-type MDnsServer struct {
|
|
|
- sync.Mutex
|
|
|
+type Server interface {
|
|
|
+ Daemon()
|
|
|
+ Close() error
|
|
|
+}
|
|
|
+
|
|
|
+type server struct {
|
|
|
zone.Zone
|
|
|
*net.Interface
|
|
|
|
|
|
- ipv4 *net.UDPConn
|
|
|
- ipv6 *net.UDPConn
|
|
|
+ ipv4 connection.UDP
|
|
|
+ ipv6 connection.UDP
|
|
|
|
|
|
running bool
|
|
|
runCh chan struct{}
|
|
|
}
|
|
|
|
|
|
-func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
|
|
|
- m := new(MDnsServer)
|
|
|
+func New(z zone.Zone, iface *net.Interface) (Server, error) {
|
|
|
+ s := new(server)
|
|
|
|
|
|
- if m == nil {
|
|
|
+ if s == nil {
|
|
|
log.Traceln(errors.Server, errors.OutOfMemory)
|
|
|
return nil, errors.OutOfMemory
|
|
|
}
|
|
@@ -40,84 +44,80 @@ func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
|
|
|
return nil, errors.ServerIsNil
|
|
|
}
|
|
|
|
|
|
- m.Zone = z
|
|
|
- m.Interface = iface
|
|
|
-
|
|
|
- var err error
|
|
|
+ s.Zone = z
|
|
|
+ s.Interface = iface
|
|
|
+ s.ipv4 = connection.New(4)
|
|
|
+ s.ipv6 = connection.New(6)
|
|
|
|
|
|
- if m.ipv4, err = net.ListenMulticastUDP("udp4", m.Interface, config.MdnsIPv4Addr); err != nil {
|
|
|
- log.Infoln(errors.Server, config.MdnsIPv4, err)
|
|
|
+ if err := s.ipv4.ListenMulticast(s.Interface, config.MdnsIPv4Addr); err != nil {
|
|
|
+ log.Traceln(errors.Server, config.MdnsIPv4, err)
|
|
|
}
|
|
|
|
|
|
- if m.ipv6, err = net.ListenMulticastUDP("udp6", m.Interface, config.MdnsIPv6Addr); err != nil {
|
|
|
- log.Infoln(errors.Server, config.MdnsIPv6, err)
|
|
|
+ if err := s.ipv6.ListenMulticast(s.Interface, config.MdnsIPv6Addr); err != nil {
|
|
|
+ log.Traceln(errors.Server, config.MdnsIPv6, err)
|
|
|
}
|
|
|
|
|
|
- if m.ipv4 == nil && m.ipv6 == nil {
|
|
|
+ if !s.ipv4.Listening() && s.ipv6.Listening() {
|
|
|
log.Traceln(errors.Server, errors.ServerNoListenersStarted)
|
|
|
return nil, errors.ServerNoListenersStarted
|
|
|
}
|
|
|
|
|
|
- m.runCh = make(chan struct{})
|
|
|
- m.running = true
|
|
|
+ s.runCh = make(chan struct{})
|
|
|
+ s.running = true
|
|
|
|
|
|
- go m.recv(m.ipv4)
|
|
|
- go m.recv(m.ipv6)
|
|
|
+ go s.recv(s.ipv4)
|
|
|
+ go s.recv(s.ipv6)
|
|
|
|
|
|
- return m, nil
|
|
|
+ return s, nil
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) shutdownListener() {
|
|
|
+func (s *server) shutdownListener() {
|
|
|
log.Traceln("Shutdown listener set on ctrl+x")
|
|
|
c := make(chan os.Signal, 1)
|
|
|
signal.Notify(c, os.Interrupt)
|
|
|
go func() {
|
|
|
for range c {
|
|
|
- m.Shutdown()
|
|
|
+ s.Close()
|
|
|
}
|
|
|
}()
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) Shutdown() {
|
|
|
- log.Traceln("Shutting down MDNS-server")
|
|
|
- m.Lock()
|
|
|
- defer m.Unlock()
|
|
|
-
|
|
|
- if !m.running {
|
|
|
- return
|
|
|
+func (s *server) Close() error {
|
|
|
+ log.Traceln(errors.Server, "Closing")
|
|
|
+ if !s.running {
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
- if m.ipv4 != nil {
|
|
|
- m.ipv4.Close()
|
|
|
+ if err := s.ipv4.Close(); err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
- if m.ipv6 != nil {
|
|
|
- m.ipv6.Close()
|
|
|
+ if err := s.ipv6.Close(); err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
- m.running = false
|
|
|
- close(m.runCh)
|
|
|
+ s.running = false
|
|
|
+ close(s.runCh)
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) Daemon() {
|
|
|
+func (s *server) Daemon() {
|
|
|
log.Traceln("Daemon running.")
|
|
|
- go m.shutdownListener()
|
|
|
- <-m.runCh
|
|
|
+ go s.shutdownListener()
|
|
|
+ <-s.runCh
|
|
|
log.Traceln("Daemon ending.")
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) recv(i *net.UDPConn) {
|
|
|
- if i == nil {
|
|
|
+func (s *server) recv(c connection.UDP) {
|
|
|
+ if c == nil {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
buf := make([]byte, config.BufSize)
|
|
|
|
|
|
- for m.running {
|
|
|
- m.Lock()
|
|
|
- i.SetReadDeadline(time.Now().Add(config.BufReadDeadline * time.Second))
|
|
|
- _, addr, err := i.ReadFromUDP(buf)
|
|
|
- m.Unlock()
|
|
|
+ for s.running {
|
|
|
+ _, addr, err := c.Read(buf)
|
|
|
|
|
|
if err != nil {
|
|
|
if !strings.Contains(err.Error(), "i/o timeout") {
|
|
@@ -126,12 +126,12 @@ func (m *MDnsServer) recv(i *net.UDPConn) {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- go m.handlePacket(buf, addr)
|
|
|
+ go s.handlePacket(buf, addr)
|
|
|
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
|
|
|
+func (s *server) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
|
|
|
|
|
|
if msg == nil {
|
|
|
return 0, nil
|
|
@@ -143,37 +143,47 @@ func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
- if addr.IP.To4() != nil {
|
|
|
- return m.ipv4.WriteToUDP(buf, addr)
|
|
|
+ _, t := host.IP(addr.IP).Type()
|
|
|
+
|
|
|
+ if !s.running {
|
|
|
+ return 0, nil
|
|
|
}
|
|
|
|
|
|
- return m.ipv6.WriteToUDP(buf, addr)
|
|
|
+ switch t {
|
|
|
+ case host.IPv4:
|
|
|
+ return s.ipv4.Write(buf, addr)
|
|
|
+ case host.IPv6:
|
|
|
+ return s.ipv6.Write(buf, addr)
|
|
|
+ default:
|
|
|
+ log.Traceln(errors.Server, addr, t, errors.ServerUnknownConnectionAddr)
|
|
|
+ return 0, errors.ServerUnknownConnectionAddr
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) handlePacket(p []byte, addr *net.UDPAddr) {
|
|
|
+func (s *server) handlePacket(p []byte, addr *net.UDPAddr) {
|
|
|
msg := new(dns.Msg)
|
|
|
|
|
|
if err := msg.Unpack(p); err != nil {
|
|
|
log.Warningln(errors.Server, addr, err)
|
|
|
}
|
|
|
|
|
|
- umsg, mmsg, err := m.handleMsg(msg)
|
|
|
+ umsg, mmsg, err := s.handleMsg(msg)
|
|
|
|
|
|
if err != nil {
|
|
|
log.Warningln(errors.Server, addr, err)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if n, err := m.send(umsg, addr); err != nil {
|
|
|
+ if n, err := s.send(umsg, addr); err != nil {
|
|
|
log.Warningln(errors.Server, "Wrote", n, err)
|
|
|
}
|
|
|
|
|
|
- if n, err := m.send(mmsg, addr); err != nil {
|
|
|
+ if n, err := s.send(mmsg, addr); err != nil {
|
|
|
log.Warningln(errors.Server, "Wrote", n, err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
|
|
|
+func (s *server) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
|
|
|
if msg.Opcode != dns.OpcodeQuery {
|
|
|
log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
|
|
|
return nil, nil, errors.ServerReceivedNonQueryOpcode
|
|
@@ -192,16 +202,16 @@ func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
|
|
|
var uAnswer, mAnswer []dns.RR
|
|
|
|
|
|
for _, q := range msg.Question {
|
|
|
- uRecords, mRecords := m.handleQuestion(q)
|
|
|
+ uRecords, mRecords := s.handleQuestion(q)
|
|
|
uAnswer = append(uAnswer, uRecords...)
|
|
|
mAnswer = append(mAnswer, mRecords...)
|
|
|
}
|
|
|
|
|
|
- return m.handleResponse(msg, true, uAnswer), m.handleResponse(msg, false, mAnswer), nil
|
|
|
+ return s.handleResponse(msg, true, uAnswer), s.handleResponse(msg, false, mAnswer), nil
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
|
|
|
- r := m.Records(question)
|
|
|
+func (s *server) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
|
|
|
+ r := s.Records(question)
|
|
|
|
|
|
if len(r) == 0 {
|
|
|
return nil, nil
|
|
@@ -218,7 +228,7 @@ func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR)
|
|
|
return nil, r
|
|
|
}
|
|
|
|
|
|
-func (m *MDnsServer) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
|
|
|
+func (s *server) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
|
|
|
id := uint16(0)
|
|
|
if uni {
|
|
|
id = msg.Id
|