123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- package server
- import (
- "net"
- "os"
- "os/signal"
- "strings"
- "git.giaever.org/joachimmg/go-log.git/log"
- "git.giaever.org/joachimmg/m-dns/config"
- "git.giaever.org/joachimmg/m-dns/connection"
- "git.giaever.org/joachimmg/m-dns/errors"
- "git.giaever.org/joachimmg/m-dns/host"
- "git.giaever.org/joachimmg/m-dns/zone"
- "github.com/miekg/dns"
- )
// Server is the handle returned by New for a running mDNS responder.
type Server interface {
	// Close stops the server. In the implementation in this file it closes
	// both multicast connections and releases any goroutine blocked in
	// Daemon.
	Close() error

	// Daemon blocks the caller until the server has been closed.
	Daemon()
}
- type server struct {
- zone.Zone
- *net.Interface
- ipv4 connection.UDP
- ipv6 connection.UDP
- running bool
- runCh chan struct{}
- }
- func New(z zone.Zone, iface *net.Interface) (Server, error) {
- s := new(server)
- if s == nil {
- log.Traceln(errors.Server, errors.OutOfMemory)
- return nil, errors.OutOfMemory
- }
- if z == nil {
- log.Traceln(errors.Server, errors.ServerIsNil)
- return nil, errors.ServerIsNil
- }
- s.Zone = z
- s.Interface = iface
- s.ipv4 = connection.New(4)
- s.ipv6 = connection.New(6)
- if err := s.ipv4.ListenMulticast(s.Interface, config.MdnsIPv4Addr); err != nil {
- log.Traceln(errors.Server, config.MdnsIPv4, err)
- }
- if err := s.ipv6.ListenMulticast(s.Interface, config.MdnsIPv6Addr); err != nil {
- log.Traceln(errors.Server, config.MdnsIPv6, err)
- }
- if !s.ipv4.Listening() && s.ipv6.Listening() {
- log.Traceln(errors.Server, errors.ServerNoListenersStarted)
- return nil, errors.ServerNoListenersStarted
- }
- s.runCh = make(chan struct{})
- s.running = true
- go s.recv(s.ipv4)
- go s.recv(s.ipv6)
- return s, nil
- }
- func (s *server) shutdownListener() {
- log.Traceln("Shutdown listener set on ctrl+x")
- c := make(chan os.Signal, 1)
- signal.Notify(c, os.Interrupt)
- go func() {
- for range c {
- s.Close()
- }
- }()
- }
- func (s *server) Close() error {
- log.Traceln(errors.Server, "Closing")
- if !s.running {
- return nil
- }
- if err := s.ipv4.Close(); err != nil {
- return err
- }
- if err := s.ipv6.Close(); err != nil {
- return err
- }
- s.running = false
- close(s.runCh)
- return nil
- }
- func (s *server) Daemon() {
- log.Traceln("Daemon running.")
- go s.shutdownListener()
- <-s.runCh
- log.Traceln("Daemon ending.")
- }
- func (s *server) recv(c connection.UDP) {
- if c == nil {
- return
- }
- buf := make([]byte, config.BufSize)
- for s.running {
- _, addr, err := c.Read(buf)
- if err != nil {
- if !strings.Contains(err.Error(), "i/o timeout") {
- log.Traceln(errors.Server, err)
- }
- continue
- }
- go s.handlePacket(buf, addr)
- }
- }
- func (s *server) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
- if msg == nil {
- return 0, nil
- }
- buf, err := msg.Pack()
- if err != nil {
- return 0, err
- }
- _, t := host.IP(addr.IP).Type()
- if !s.running {
- return 0, nil
- }
- switch t {
- case host.IPv4:
- return s.ipv4.Write(buf, addr)
- case host.IPv6:
- return s.ipv6.Write(buf, addr)
- default:
- log.Traceln(errors.Server, addr, t, errors.ServerUnknownConnectionAddr)
- return 0, errors.ServerUnknownConnectionAddr
- }
- }
- func (s *server) handlePacket(p []byte, addr *net.UDPAddr) {
- msg := new(dns.Msg)
- if err := msg.Unpack(p); err != nil {
- log.Warningln(errors.Server, addr, err)
- }
- umsg, mmsg, err := s.handleMsg(msg)
- if err != nil {
- log.Warningln(errors.Server, addr, err)
- return
- }
- if n, err := s.send(umsg, addr); err != nil {
- log.Warningln(errors.Server, "Wrote", n, err)
- }
- if n, err := s.send(mmsg, addr); err != nil {
- log.Warningln(errors.Server, "Wrote", n, err)
- }
- }
- func (s *server) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
- if msg.Opcode != dns.OpcodeQuery {
- log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
- return nil, nil, errors.ServerReceivedNonQueryOpcode
- }
- if msg.Rcode != 0 {
- log.Traceln(errors.Server, errors.ServerReceivedNonZeroRcode)
- return nil, nil, errors.ServerReceivedNonZeroRcode
- }
- if msg.Truncated {
- log.Traceln(errors.Server, errors.ServerReceivedTruncatedSet)
- return nil, nil, errors.ServerReceivedTruncatedSet
- }
- var uAnswer, mAnswer []dns.RR
- for _, q := range msg.Question {
- uRecords, mRecords := s.handleQuestion(q)
- uAnswer = append(uAnswer, uRecords...)
- mAnswer = append(mAnswer, mRecords...)
- }
- return s.handleResponse(msg, true, uAnswer), s.handleResponse(msg, false, mAnswer), nil
- }
- func (s *server) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
- r := s.Records(question)
- if len(r) == 0 {
- return nil, nil
- }
- for i, rec := range r {
- log.Traceln(errors.Server, "Record", i, rec)
- }
- if question.Qclass&(1<<15) != 0 || config.ForceUnicast {
- return r, nil
- }
- return nil, r
- }
- func (s *server) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
- id := uint16(0)
- if uni {
- id = msg.Id
- }
- if len(ans) == 0 {
- return nil
- }
- return &dns.Msg{
- MsgHdr: dns.MsgHdr{
- Id: id,
- Response: true,
- Opcode: dns.OpcodeQuery,
- Authoritative: true,
- },
- Compress: true,
- Answer: ans,
- }
- }
|