server.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package server
  2. import (
  3. "net"
  4. "os"
  5. "os/signal"
  6. "strings"
  7. "sync"
  8. "time"
  9. "git.giaever.org/joachimmg/go-log.git/log"
  10. "git.giaever.org/joachimmg/m-dns/config"
  11. "git.giaever.org/joachimmg/m-dns/errors"
  12. "git.giaever.org/joachimmg/m-dns/zone"
  13. "github.com/miekg/dns"
  14. )
  15. type MDnsServer struct {
  16. sync.Mutex
  17. zone.Zone
  18. *net.Interface
  19. ipv4 *net.UDPConn
  20. ipv6 *net.UDPConn
  21. running bool
  22. runCh chan struct{}
  23. }
  24. func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
  25. m := new(MDnsServer)
  26. if m == nil {
  27. log.Traceln(errors.Server, errors.OutOfMemory)
  28. return nil, errors.OutOfMemory
  29. }
  30. if z == nil {
  31. log.Traceln(errors.Server, errors.ServerIsNil)
  32. return nil, errors.ServerIsNil
  33. }
  34. m.Zone = z
  35. m.Interface = iface
  36. var err error
  37. if m.ipv4, err = net.ListenMulticastUDP("udp4", m.Interface, config.MdnsIPv4Addr); err != nil {
  38. log.Infoln(errors.Server, config.MdnsIPv4, err)
  39. }
  40. if m.ipv6, err = net.ListenMulticastUDP("udp6", m.Interface, config.MdnsIPv6Addr); err != nil {
  41. log.Infoln(errors.Server, config.MdnsIPv6, err)
  42. }
  43. if m.ipv4 == nil && m.ipv6 == nil {
  44. log.Traceln(errors.Server, errors.ServerNoListenersStarted)
  45. return nil, errors.ServerNoListenersStarted
  46. }
  47. m.runCh = make(chan struct{})
  48. m.running = true
  49. go m.recv(m.ipv4)
  50. go m.recv(m.ipv6)
  51. return m, nil
  52. }
  53. func (m *MDnsServer) shutdownListener() {
  54. log.Traceln("Shutdown listener set on ctrl+x")
  55. c := make(chan os.Signal, 1)
  56. signal.Notify(c, os.Interrupt)
  57. go func() {
  58. for range c {
  59. m.Shutdown()
  60. }
  61. }()
  62. }
  63. func (m *MDnsServer) Shutdown() {
  64. log.Traceln("Shutting down MDNS-server")
  65. m.Lock()
  66. defer m.Unlock()
  67. if !m.running {
  68. return
  69. }
  70. if m.ipv4 != nil {
  71. m.ipv4.Close()
  72. }
  73. if m.ipv6 != nil {
  74. m.ipv6.Close()
  75. }
  76. m.running = false
  77. close(m.runCh)
  78. }
  79. func (m *MDnsServer) Daemon() {
  80. log.Traceln("Daemon running.")
  81. go m.shutdownListener()
  82. <-m.runCh
  83. log.Traceln("Daemon ending.")
  84. }
  85. func (m *MDnsServer) recv(i *net.UDPConn) {
  86. if i == nil {
  87. return
  88. }
  89. buf := make([]byte, config.BufSize)
  90. for m.running {
  91. m.Lock()
  92. i.SetReadDeadline(time.Now().Add(config.BufReadDeadline * time.Second))
  93. _, addr, err := i.ReadFromUDP(buf)
  94. m.Unlock()
  95. if err != nil {
  96. if !strings.Contains(err.Error(), "i/o timeout") {
  97. log.Traceln(errors.Server, err)
  98. }
  99. continue
  100. }
  101. go m.handlePacket(buf, addr)
  102. }
  103. }
  104. func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
  105. if msg == nil {
  106. return 0, nil
  107. }
  108. buf, err := msg.Pack()
  109. if err != nil {
  110. return 0, err
  111. }
  112. if addr.IP.To4() != nil {
  113. return m.ipv4.WriteToUDP(buf, addr)
  114. }
  115. return m.ipv6.WriteToUDP(buf, addr)
  116. }
  117. func (m *MDnsServer) handlePacket(p []byte, addr *net.UDPAddr) {
  118. msg := new(dns.Msg)
  119. if err := msg.Unpack(p); err != nil {
  120. log.Warningln(errors.Server, addr, err)
  121. }
  122. umsg, mmsg, err := m.handleMsg(msg)
  123. if err != nil {
  124. log.Warningln(errors.Server, addr, err)
  125. return
  126. }
  127. if n, err := m.send(umsg, addr); err != nil {
  128. log.Warningln(errors.Server, "Wrote", n, err)
  129. }
  130. if n, err := m.send(mmsg, addr); err != nil {
  131. log.Warningln(errors.Server, "Wrote", n, err)
  132. }
  133. }
  134. func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
  135. if msg.Opcode != dns.OpcodeQuery {
  136. log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
  137. return nil, nil, errors.ServerReceivedNonQueryOpcode
  138. }
  139. if msg.Rcode != 0 {
  140. log.Traceln(errors.Server, errors.ServerReceivedNonZeroRcode)
  141. return nil, nil, errors.ServerReceivedNonZeroRcode
  142. }
  143. if msg.Truncated {
  144. log.Traceln(errors.Server, errors.ServerReceivedTruncatedSet)
  145. return nil, nil, errors.ServerReceivedTruncatedSet
  146. }
  147. var uAnswer, mAnswer []dns.RR
  148. for _, q := range msg.Question {
  149. uRecords, mRecords := m.handleQuestion(q)
  150. uAnswer = append(uAnswer, uRecords...)
  151. mAnswer = append(mAnswer, mRecords...)
  152. }
  153. return m.handleResponse(msg, true, uAnswer), m.handleResponse(msg, false, mAnswer), nil
  154. }
  155. func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
  156. r := m.Records(question)
  157. if len(r) == 0 {
  158. return nil, nil
  159. }
  160. for i, rec := range r {
  161. log.Traceln(errors.Server, "Record", i, rec)
  162. }
  163. if question.Qclass&(1<<15) != 0 || config.ForceUnicast {
  164. return r, nil
  165. }
  166. return nil, r
  167. }
  168. func (m *MDnsServer) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
  169. id := uint16(0)
  170. if uni {
  171. id = msg.Id
  172. }
  173. if len(ans) == 0 {
  174. return nil
  175. }
  176. return &dns.Msg{
  177. MsgHdr: dns.MsgHdr{
  178. Id: id,
  179. Response: true,
  180. Opcode: dns.OpcodeQuery,
  181. Authoritative: true,
  182. },
  183. Compress: true,
  184. Answer: ans,
  185. }
  186. }