server.go 4.7 KB


  1. package server
  2. import (
  3. "net"
  4. "os"
  5. "os/signal"
  6. "strings"
  7. "git.giaever.org/joachimmg/go-log.git/log"
  8. "git.giaever.org/joachimmg/m-dns/config"
  9. "git.giaever.org/joachimmg/m-dns/connection"
  10. "git.giaever.org/joachimmg/m-dns/errors"
  11. "git.giaever.org/joachimmg/m-dns/host"
  12. "git.giaever.org/joachimmg/m-dns/zone"
  13. "github.com/miekg/dns"
  14. )
  // Server is the public handle of an mDNS responder created by New.
  type Server interface {
  	// Daemon blocks the caller while the responder is running and returns
  	// once it has been closed.
  	Daemon()

  	// Close shuts the responder down and reports any error from releasing
  	// its network connections.
  	Close() error
  }
  19. type server struct {
  20. zone.Zone
  21. *net.Interface
  22. ipv4 connection.UDP
  23. ipv6 connection.UDP
  24. running bool
  25. runCh chan struct{}
  26. }
  27. func New(z zone.Zone, iface *net.Interface) (Server, error) {
  28. s := new(server)
  29. if s == nil {
  30. log.Traceln(errors.Server, errors.OutOfMemory)
  31. return nil, errors.OutOfMemory
  32. }
  33. if z == nil {
  34. log.Traceln(errors.Server, errors.ServerIsNil)
  35. return nil, errors.ServerIsNil
  36. }
  37. s.Zone = z
  38. s.Interface = iface
  39. s.ipv4 = connection.New(4)
  40. s.ipv6 = connection.New(6)
  41. if err := s.ipv4.ListenMulticast(s.Interface, config.MdnsIPv4Addr); err != nil {
  42. log.Traceln(errors.Server, config.MdnsIPv4, err)
  43. }
  44. if err := s.ipv6.ListenMulticast(s.Interface, config.MdnsIPv6Addr); err != nil {
  45. log.Traceln(errors.Server, config.MdnsIPv6, err)
  46. }
  47. if !s.ipv4.Listening() && s.ipv6.Listening() {
  48. log.Traceln(errors.Server, errors.ServerNoListenersStarted)
  49. return nil, errors.ServerNoListenersStarted
  50. }
  51. s.runCh = make(chan struct{})
  52. s.running = true
  53. go s.recv(s.ipv4)
  54. go s.recv(s.ipv6)
  55. return s, nil
  56. }
  57. func (s *server) shutdownListener() {
  58. log.Traceln("Shutdown listener set on ctrl+x")
  59. c := make(chan os.Signal, 1)
  60. signal.Notify(c, os.Interrupt)
  61. go func() {
  62. for range c {
  63. s.Close()
  64. }
  65. }()
  66. }
  67. func (s *server) Close() error {
  68. log.Traceln(errors.Server, "Closing")
  69. if !s.running {
  70. return nil
  71. }
  72. if err := s.ipv4.Close(); err != nil {
  73. return err
  74. }
  75. if err := s.ipv6.Close(); err != nil {
  76. return err
  77. }
  78. s.running = false
  79. close(s.runCh)
  80. return nil
  81. }
  82. func (s *server) Daemon() {
  83. log.Traceln("Daemon running.")
  84. go s.shutdownListener()
  85. <-s.runCh
  86. log.Traceln("Daemon ending.")
  87. }
  88. func (s *server) recv(c connection.UDP) {
  89. if c == nil {
  90. return
  91. }
  92. buf := make([]byte, config.BufSize)
  93. for s.running {
  94. _, addr, err := c.Read(buf)
  95. if err != nil {
  96. if !strings.Contains(err.Error(), "i/o timeout") {
  97. log.Traceln(errors.Server, err)
  98. }
  99. continue
  100. }
  101. go s.handlePacket(buf, addr)
  102. }
  103. }
  104. func (s *server) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
  105. if msg == nil {
  106. return 0, nil
  107. }
  108. buf, err := msg.Pack()
  109. if err != nil {
  110. return 0, err
  111. }
  112. _, t := host.IP(addr.IP).Type()
  113. if !s.running {
  114. return 0, nil
  115. }
  116. switch t {
  117. case host.IPv4:
  118. return s.ipv4.Write(buf, addr)
  119. case host.IPv6:
  120. return s.ipv6.Write(buf, addr)
  121. default:
  122. log.Traceln(errors.Server, addr, t, errors.ServerUnknownConnectionAddr)
  123. return 0, errors.ServerUnknownConnectionAddr
  124. }
  125. }
  126. func (s *server) handlePacket(p []byte, addr *net.UDPAddr) {
  127. msg := new(dns.Msg)
  128. if err := msg.Unpack(p); err != nil {
  129. log.Warningln(errors.Server, addr, err)
  130. }
  131. umsg, mmsg, err := s.handleMsg(msg)
  132. if err != nil {
  133. log.Warningln(errors.Server, addr, err)
  134. return
  135. }
  136. if n, err := s.send(umsg, addr); err != nil {
  137. log.Warningln(errors.Server, "Wrote", n, err)
  138. }
  139. if n, err := s.send(mmsg, addr); err != nil {
  140. log.Warningln(errors.Server, "Wrote", n, err)
  141. }
  142. }
  143. func (s *server) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
  144. if msg.Opcode != dns.OpcodeQuery {
  145. log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
  146. return nil, nil, errors.ServerReceivedNonQueryOpcode
  147. }
  148. if msg.Rcode != 0 {
  149. log.Traceln(errors.Server, errors.ServerReceivedNonZeroRcode)
  150. return nil, nil, errors.ServerReceivedNonZeroRcode
  151. }
  152. if msg.Truncated {
  153. log.Traceln(errors.Server, errors.ServerReceivedTruncatedSet)
  154. return nil, nil, errors.ServerReceivedTruncatedSet
  155. }
  156. var uAnswer, mAnswer []dns.RR
  157. for _, q := range msg.Question {
  158. uRecords, mRecords := s.handleQuestion(q)
  159. uAnswer = append(uAnswer, uRecords...)
  160. mAnswer = append(mAnswer, mRecords...)
  161. }
  162. return s.handleResponse(msg, true, uAnswer), s.handleResponse(msg, false, mAnswer), nil
  163. }
  164. func (s *server) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
  165. r := s.Records(question)
  166. if len(r) == 0 {
  167. return nil, nil
  168. }
  169. for i, rec := range r {
  170. log.Traceln(errors.Server, "Record", i, rec)
  171. }
  172. if question.Qclass&(1<<15) != 0 || config.ForceUnicast {
  173. return r, nil
  174. }
  175. return nil, r
  176. }
  177. func (s *server) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
  178. id := uint16(0)
  179. if uni {
  180. id = msg.Id
  181. }
  182. if len(ans) == 0 {
  183. return nil
  184. }
  185. return &dns.Msg{
  186. MsgHdr: dns.MsgHdr{
  187. Id: id,
  188. Response: true,
  189. Opcode: dns.OpcodeQuery,
  190. Authoritative: true,
  191. },
  192. Compress: true,
  193. Answer: ans,
  194. }
  195. }