Browse Source

Restructure and added query-params and connection

Joachim M. Giæver 7 years ago
parent
commit
42e0e0acbe
12 changed files with 487 additions and 142 deletions
  1. 79 7
      client/client.go
  2. 24 0
      client/query/params/params.go
  3. 33 1
      client/query/query.go
  4. 15 2
      config/config.go
  5. 147 0
      connection/connection.go
  6. 7 0
      errors/errors.go
  7. 47 31
      example/example.go
  8. 11 11
      host/host.go
  9. 11 10
      host/ip.go
  10. 39 16
      host/string.go
  11. 73 63
      server/server.go
  12. 1 1
      zone/zone.go

+ 79 - 7
client/client.go

@@ -1,17 +1,89 @@
 package client
 
 import (
-	"sync"
-
 	"git.giaever.org/joachimmg/go-log.git/log"
-	"git.giaever.org/joachimmg/m-dns/zone"
+	"git.giaever.org/joachimmg/m-dns/config"
+	"git.giaever.org/joachimmg/m-dns/connection"
+	"git.giaever.org/joachimmg/m-dns/errors"
+	"git.giaever.org/joachimmg/m-dns/host"
 )
 
-type MDnsClient struct {
-	sync.Mutex
-	ipv4 map[string]*net.UDPConn
-	ipv6 map[string]*net.UDPConn
+type Client interface {
+	Close() error
+	Lookup(service, domain string, instances chan<- *host.Host) error
+}
+
+type client struct {
+	ipv4u connection.UDP
+	ipv6u connection.UDP
+	ipv4m connection.UDP
+	ipv6m connection.UDP
 
 	running bool
 	runCh   chan struct{}
 }
+
+func New() (Client, error) {
+
+	c := new(client)
+
+	if c == nil {
+		log.Traceln(errors.Client, errors.OutOfMemory)
+		return nil, errors.OutOfMemory
+	}
+
+	c.ipv4u = connection.New(4)
+	c.ipv6u = connection.New(6)
+
+	if err := c.ipv4u.Listen(config.ZeroIPv4Addr); err != nil {
+		log.Traceln(errors.Client, config.ZeroIPv4Addr, err)
+	}
+
+	if err := c.ipv6u.Listen(config.ZeroIPv6Addr); err != nil {
+		log.Traceln(errors.Client, config.ZeroIPv6Addr, err)
+	}
+
+	if !c.ipv4u.Listening() && !c.ipv6u.Listening() {
+		log.Traceln(errors.Client, errors.ClientUDPuFailed, config.ZeroIPv4Addr, config.ZeroIPv6Addr)
+		return nil, errors.ClientUDPuFailed
+	}
+
+	c.ipv4m = connection.New(4)
+	c.ipv6m = connection.New(6)
+
+	if err := c.ipv4m.ListenMulticast(nil, config.MdnsIPv4Addr); err != nil {
+		log.Traceln(errors.Client, config.MdnsIPv4Addr, err)
+	}
+
+	if err := c.ipv6m.ListenMulticast(nil, config.MdnsIPv6Addr); err != nil {
+		log.Traceln(errors.Client, config.MdnsIPv6Addr, err)
+	}
+
+	if !c.ipv4m.Listening() && !c.ipv6m.Listening() {
+		log.Traceln(errors.Client, errors.ClientUDPmFailed, config.MdnsIPv4Addr, config.MdnsIPv6Addr)
+		return nil, errors.ClientUDPmFailed
+	}
+
+	return c, nil
+}
+
+func (c *client) Close() error {
+
+	if !c.running {
+		return nil
+	}
+
+	log.Traceln(errors.Client, "Closing")
+	c.ipv4u.Close()
+	c.ipv6u.Close()
+
+	c.ipv4m.Close()
+	c.ipv6m.Close()
+
+	c.running = nil
+	return nil
+}
+
+func (c *client) Lookup(service, domain string, instances chan<- *host.Host) error {
+	return nil
+}

+ 24 - 0
client/query/params/params.go

@@ -0,0 +1,24 @@
+package params
+
+type Params struct {
+	*net.Interface
+
+	Service host.Service
+	Domain  host.Domain
+}
+
+func New(service, domain string, iface *net.Interface) (*Params, error) {
+	p := new(Params)
+
+	if p.service, err = host.String(service).IsServiceVariable(); err != nil {
+		log.Traceln(errors.Params, service, err)
+		return nil, err
+	}
+
+	if p.domain, err = host.String(domain).IsDomainVariable(); err != nil {
+		log.Traceln(errors.Params, domain, err)
+		return nil, err
+	}
+
+	return p, nil
+}

+ 33 - 1
client/query/query.go

@@ -1,8 +1,40 @@
 package query
 
 import (
+	"git.giaever.org/joachimmg/m-dns/errors"
 	"git.giaever.org/joachimmg/m-dns/host"
+	"git.git.giaever.org/joachimmg/m-dns/query/params"
+)
 
 type Query struct {
-	Service	host.HostString
+	params.Params
+
+	unicast
+	timeout   time.Duration
+	instances chan<- *host.Host
+}
+
+func New(p params.Params, unicast bool, timeout int) (Query, error) {
+	q := new(Query)
+
+	if q == nil {
+		log.Traceln(errors.Client, errors.OutOfMemory)
+		return nil, errors.OutOfMemory
+	}
+
+	q.Params = p
+	q.unicast = u
+	q.timeout = t
+
+	return q, nil
+}
+
+func NewDefault(service, domain string) (Query, error) {
+	p, err := params.New(service, domain, nil)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return New(p, config.ForceUnicast, config.QueryTimeout)
 }

+ 15 - 2
config/config.go

@@ -13,8 +13,11 @@ const (
 
 	DefaultTTL = 120
 
-	BufSize         = 65536
-	BufReadDeadline = 2
+	BufSize          = 65536
+	BufReadDeadline  = 2
+	BufWriteDeadline = 2
+
+	QueryTimeout = 1
 )
 
 var (
@@ -26,4 +29,14 @@ var (
 		IP:   net.ParseIP(MdnsIPv6),
 		Port: MdnsPort,
 	}
+
+	ZeroIPv4Addr = &net.UDPAddr{
+		IP:   net.IPv4zero,
+		Port: 0,
+	}
+
+	ZeroIPv6Addr = &net.UDPAddr{
+		IP:   net.IPv6zero,
+		Port: 0,
+	}
 )

+ 147 - 0
connection/connection.go

@@ -0,0 +1,147 @@
+package connection
+
+import (
+	"net"
+	"sync"
+	"time"
+
+	//"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/config"
+)
+
+type Conn interface {
+	Read(b []byte) (int, *net.UDPAddr, error)
+	Write(b []byte, addr *net.UDPAddr) (int, error)
+	Close() error
+	Lock()
+	RLock()
+	WLock()
+	RUnlock()
+	WUnlock()
+	Unlock()
+}
+
+type UDP interface {
+	Conn
+	Listening() bool
+	Listen(addr *net.UDPAddr) error
+	ListenMulticast(iface *net.Interface, addr *net.UDPAddr) error
+}
+
+type conn struct {
+	*net.UDPConn
+
+	r sync.Mutex
+	w sync.Mutex
+}
+
+type UDP4 struct {
+	conn
+}
+
+type UDP6 struct {
+	conn
+}
+
+func New(net int) UDP {
+	if net == 4 {
+		return new(UDP4)
+	}
+	return new(UDP6)
+}
+
+func (c *conn) Lock() {
+	c.RLock()
+	c.WLock()
+}
+
+func (c *conn) RLock() {
+	c.r.Lock()
+}
+
+func (c *conn) WLock() {
+	c.w.Lock()
+}
+
+func (c *conn) Unlock() {
+	c.RUnlock()
+	c.WUnlock()
+}
+
+func (c *conn) RUnlock() {
+	c.r.Unlock()
+}
+
+func (c *conn) WUnlock() {
+	c.w.Unlock()
+}
+
+func (c *conn) Listening() bool {
+	return c.UDPConn != nil
+}
+
+func (c *conn) Read(b []byte) (int, *net.UDPAddr, error) {
+	c.RLock()
+	defer c.RUnlock()
+
+	if !c.Listening() {
+		return 0, nil, nil
+	}
+
+	if config.BufReadDeadline > 0 {
+		c.SetReadDeadline(time.Now().Add(config.BufReadDeadline * time.Second))
+	}
+
+	return c.ReadFromUDP(b)
+}
+
+func (c *conn) Write(b []byte, addr *net.UDPAddr) (int, error) {
+	c.WLock()
+	defer c.WUnlock()
+
+	if !c.Listening() {
+		return 0, nil
+	}
+
+	if config.BufWriteDeadline > 0 {
+		c.SetWriteDeadline(time.Now().Add(config.BufWriteDeadline * time.Second))
+	}
+
+	return c.WriteToUDP(b, addr)
+}
+
+func (c *conn) Close() error {
+	c.Lock()
+	defer c.Unlock()
+	if !c.Listening() {
+		return nil
+	}
+
+	err := c.UDPConn.Close()
+	c.UDPConn = nil
+	return err
+}
+
+func (u *UDP4) Listen(addr *net.UDPAddr) error {
+	var err error
+	u.UDPConn, err = net.ListenUDP("udp4", addr)
+	return err
+}
+
+func (u *UDP4) ListenMulticast(iface *net.Interface, addr *net.UDPAddr) error {
+	var err error
+	u.UDPConn, err = net.ListenMulticastUDP("udp4", iface, addr)
+	return err
+}
+
+func (u *UDP6) Listen(addr *net.UDPAddr) error {
+	var err error
+	u.UDPConn, err = net.ListenUDP("udp6", addr)
+	return err
+}
+
+func (u *UDP6) ListenMulticast(iface *net.Interface, addr *net.UDPAddr) error {
+	var err error
+	u.UDPConn, err = net.ListenMulticastUDP("udp6", iface, addr)
+	return err
+}

+ 7 - 0
errors/errors.go

@@ -13,11 +13,17 @@ const (
 	HostIP     Prefix = "HostIP"
 	HostPort   Prefix = "HostPort"
 	Server     Prefix = "Server"
+	Client     Prefix = "Client"
+	Query      Prefix = "Query"
+	Params     Prefix = "Params"
 )
 
 var (
 	OutOfMemory = e.New("Out of memory")
 
+	ClientUDPuFailed = e.New("Failed to bind to any unicast address.")
+	ClientUDPmFailed = e.New("Failed to bind to any multicast address.")
+
 	HostIsNil   = e.New("Host is nil.")
 	ServerIsNil = e.New("Server is nil.")
 
@@ -39,6 +45,7 @@ var (
 	ServerReceivedNonZeroRcode   = e.New("Received non-zero Rcode")
 	ServerReceivedTruncatedSet   = e.New("Reveived trucated bit set")
 	ServerNoResponseForQuestion  = e.New("No response for question.")
+	ServerUnknownConnectionAddr  = e.New("Unknown connection on IP.")
 )
 
 func (p Prefix) String() string {

+ 47 - 31
example/example.go

@@ -8,12 +8,16 @@ import (
 	"os"
 
 	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/client"
 	"git.giaever.org/joachimmg/m-dns/host"
 	"git.giaever.org/joachimmg/m-dns/server"
 	"git.giaever.org/joachimmg/m-dns/zone"
 )
 
+var runServer bool
+
 func init() {
+	flag.BoolVar(&runServer, "s", false, "Server if set to true, client otherwise.")
 	flag.Parse()
 }
 
@@ -24,45 +28,57 @@ func main() {
 		log.Errorln(err)
 	}
 
-	txt := []string{
-		"login=true",
-		"admin=/admin",
-		"autosign=",
-	}
+	if runServer {
+		txt := []string{
+			"login=true",
+			"admin=/admin",
+			"autosign=",
+		}
 
-	host, err := host.New(
-		"This is _a_ .dotted. instance",
-		"_myservice._tcp",
-		"local",
-		hostname,
-		[]net.IP{net.ParseIP("192.168.1.128")},
-		8080,
-		txt,
-	)
+		host, err := host.New(
+			"This is _a_ .dotted. instance",
+			"_myservice._tcp",
+			"local",
+			hostname,
+			[]net.IP{net.ParseIP("192.168.1.128"), net.IPv4zero},
+			8080,
+			txt,
+		)
 
-	if err != nil {
-		log.Errorln(err)
-	}
+		if err != nil {
+			log.Errorln(err)
+		}
 
-	zone, err := zone.New(host)
+		zone, err := zone.New(host)
 
-	if err != nil {
-		log.Errorln(err)
-	}
+		if err != nil {
+			log.Errorln(err)
+		}
 
-	mdnss, err := server.New(zone, nil)
+		mdns, err := server.New(zone, nil)
 
-	if err != nil {
-		log.Errorln(err)
-	}
-	log.Traceln(mdnss)
+		if err != nil {
+			log.Errorln(err)
+		}
+		log.Traceln(mdns)
 
-	//mdnss.Daemon()
+		defer mdns.Close()
+		//mdnss.Daemon()
 
-	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		fmt.Fprintf(w, "Hello, I'm just here hangig around.")
-	})
+		http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			fmt.Fprintf(w, "Hello, I'm just here hangig around.")
+		})
 
-	log.Panicln(http.ListenAndServe(":8080", nil))
+		log.Panicln(http.ListenAndServe(":8080", nil))
+	} else {
+		mdns, err := client.New()
 
+		if err != nil {
+			log.Panicln(err)
+		}
+
+		defer mdns.Close()
+
+		log.Traceln(mdns)
+	}
 }

+ 11 - 11
host/host.go

@@ -34,7 +34,7 @@ type host struct {
 	txt      []TXT
 }
 
-func New(instance, service, domain, hostname string, ip []net.IP, port int, txt []string) (*host, error) {
+func New(instance, service, domain, hostname string, ip []net.IP, port int, txt []string) (Host, error) {
 
 	h := new(host)
 
@@ -44,19 +44,19 @@ func New(instance, service, domain, hostname string, ip []net.IP, port int, txt
 	}
 
 	var err error
-	if h.instance, err = String(instance).isInstanceVariable(); err != nil {
+	if h.instance, err = String(instance).IsInstanceVariable(); err != nil {
 		return nil, err
 	}
 
-	if h.service, err = String(service).isServiceVariable(); err != nil {
+	if h.service, err = String(service).IsServiceVariable(); err != nil {
 		return nil, err
 	}
 
-	if h.domain, err = String(domain).isDomainVariable(); err != nil {
+	if h.domain, err = String(domain).IsDomainVariable(); err != nil {
 		return nil, err
 	}
 
-	if h.hostname, err = String(hostname).isHostnameVariable(); err != nil {
+	if h.hostname, err = String(hostname).IsHostnameVariable(); err != nil {
 		return nil, err
 	}
 
@@ -91,10 +91,10 @@ func New(instance, service, domain, hostname string, ip []net.IP, port int, txt
 	h.port = Port(port)
 
 	for _, t := range txt {
-		ht, err := String(t).isTxtVariable()
+		ht, err := String(t).IsTxtVariable()
 		if err != nil {
 			return nil, err
-		} else if len(ht) != 0 {
+		} else if !ht.Empty() { // Silently ignore empty txt entries
 			h.txt = append(h.txt, ht)
 		}
 	}
@@ -147,19 +147,19 @@ func (h *host) GetTXTs() []string {
 }
 
 func (h *host) GetServiceAddr() HostString {
-	return String(fmt.Sprintf("%s.%s.", h.GetService(), h.GetDomain()))
+	return h.GetService().ServiceAddr(h.GetDomain())
 }
 
 func (h *host) GetInstanceAddr() HostString {
-	return String(fmt.Sprintf("%s.%s", h.GetInstance().EncodedInstance(), h.GetServiceAddr()))
+	return h.GetInstance().InstanceAddr(h.GetService(), h.GetDomain())
 }
 
 func (h *host) GetHostnameAddr() HostString {
-	return String(fmt.Sprintf("%s.%s.", h.GetHostname(), h.GetDomain()))
+	return h.GetHostname().HostnameAddr(h.GetDomain())
 }
 
 func (h *host) GetDiscoveryAddr() HostString {
-	return String(fmt.Sprintf("_services._dns-sd._udp.%s.", h.GetDomain()))
+	return h.GetDomain().DiscoveryAddr()
 }
 
 func (h *host) String() string {

+ 11 - 10
host/ip.go

@@ -2,18 +2,19 @@ package host
 
 import (
 	"net"
-	"os"
+	//"os" Was used for lookup IP on hostname
 
 	"git.giaever.org/joachimmg/go-log.git/log"
 	"git.giaever.org/joachimmg/m-dns/errors"
 )
 
-type IPvType int
+type IPvType uint
 
 const (
-	IPv4 IPvType = iota << 1
-	IPv6 IPvType = iota << 1
-	NoIP IPvType = iota << 1
+	NoIP  IPvType = 0x1
+	SysIP IPvType = 0x2
+	IPv4  IPvType = 0x4
+	IPv6  IPvType = 0x8
 )
 
 type HostIP interface {
@@ -40,7 +41,7 @@ func (i IP) Type() (net.IP, IPvType) {
 			net.IPv4bcast.String(),
 			net.IPv4allsys.String(),
 			net.IPv4allrouter.String():
-			return nil, NoIP
+			return ip, SysIP | IPv4
 		default:
 			return ip, IPv4
 		}
@@ -54,7 +55,7 @@ func (i IP) Type() (net.IP, IPvType) {
 			net.IPv6linklocalallnodes.String(),
 			net.IPv6linklocalallrouters.String(),
 			net.IPv6interfacelocalallnodes.String():
-			return nil, NoIP
+			return ip, SysIP | IPv6
 		default:
 			return ip, IPv6
 		}
@@ -66,12 +67,12 @@ func (i IP) Type() (net.IP, IPvType) {
 func (i IP) validForHostname(hn HostString) (IP, error) {
 	log.Traceln(errors.HostIP, i, hn)
 
-	if _, t := i.Type(); t == NoIP {
+	if _, t := i.Type(); t == NoIP || (t&SysIP) == SysIP {
 		log.Traceln(errors.HostIP, errors.HostIPIsInvalid)
 		return nil, errors.HostIPIsInvalid
 	}
 
-	f := false
+	/*f := false
 	if lhn, _ := os.Hostname(); lhn == hn.String() {
 		addrs, err := net.InterfaceAddrs()
 
@@ -105,7 +106,7 @@ func (i IP) validForHostname(hn HostString) (IP, error) {
 	if f == false {
 		log.Traceln(errors.HostIP, errors.HostIPIsInvalid)
 		return nil, errors.HostIPIsInvalid
-	}
+	}*/
 
 	return i, nil
 }

+ 39 - 16
host/string.go

@@ -13,25 +13,30 @@ import (
 
 type HostString interface {
 	String() string
+	Empty() bool
+	dotted() string
 }
 
 type Instance interface {
 	HostString
-	EncodedInstance() string
+	InstanceAddr(sr Service, d Domain) HostString
 }
 
 type Service interface {
 	HostString
 	Types() []string
 	RootType() string
+	ServiceAddr(d Domain) HostString
 }
 
 type Domain interface {
 	HostString
+	DiscoveryAddr() HostString
 }
 
 type Hostname interface {
 	HostString
+	HostnameAddr(d Domain) HostString
 }
 
 type TXT interface {
@@ -48,13 +53,13 @@ func (s String) String() string {
 	return string(s)
 }
 
-func (s String) isEmpty() bool {
+func (s String) Empty() bool {
 	return len(s) == 0
 }
 
 func (s String) isValid() error {
 	log.Traceln(errors.HostString, s)
-	if s.isEmpty() {
+	if s.Empty() {
 		log.Traceln(errors.HostString, s, errors.HostStringIsEmpty)
 		return errors.HostStringIsEmpty
 	}
@@ -67,8 +72,8 @@ func (s String) isValid() error {
 	return nil
 }
 
-func (s String) isInstanceVariable() (String, error) {
-	if s.isEmpty() {
+func (s String) IsInstanceVariable() (Instance, error) {
+	if s.Empty() {
 		log.Traceln(errors.HostString, errors.HostStringIsInvalidInstance)
 		return EmptyString, errors.HostStringIsInvalidInstance
 	}
@@ -83,7 +88,7 @@ func (s String) isInstanceVariable() (String, error) {
 	return EmptyString, errors.HostStringIsInvalidInstance
 }
 
-func (s String) isServiceVariable() (String, error) {
+func (s String) IsServiceVariable() (Service, error) {
 	if err := s.isValid(); err != nil {
 		log.Traceln(errors.HostString, errors.HostStringIsInvalidService, err)
 		return EmptyString, errors.HostStringIsInvalidService
@@ -103,7 +108,7 @@ func (s String) isServiceVariable() (String, error) {
 	return EmptyString, errors.HostStringIsInvalidService
 }
 
-func (s String) isDomainVariable() (String, error) {
+func (s String) IsDomainVariable() (Domain, error) {
 	if err := s.isValid(); err != nil {
 		log.Traceln(errors.HostString, errors.HostStringIsInvalidDomain)
 		return EmptyString, errors.HostStringIsInvalidDomain
@@ -121,16 +126,14 @@ func (s String) isDomainVariable() (String, error) {
 	return s, nil
 }
 
-func (s String) isHostnameVariable() (String, error) {
+func (s String) IsHostnameVariable() (Hostname, error) {
 	if err := s.isValid(); err != nil {
 		log.Traceln(errors.HostString, errors.HostStringIsInvalidHostname, err)
 		return EmptyString, errors.HostStringIsInvalidHostname
 	}
 
-	if hostname, err := os.Hostname(); err == nil {
-		if s.String() == hostname {
-			return s, nil
-		}
+	if hostname, _ := os.Hostname(); s.String() == hostname {
+		return s, nil
 	}
 
 	if _, err := net.LookupHost(s.String()); err != nil {
@@ -141,9 +144,9 @@ func (s String) isHostnameVariable() (String, error) {
 	return s, nil
 }
 
-func (s String) isTxtVariable() (String, error) {
+func (s String) IsTxtVariable() (TXT, error) {
 	log.Traceln(errors.HostString, s)
-	if s.isEmpty() || s[:1][0] == '=' {
+	if s.Empty() || s[:1][0] == '=' {
 		return EmptyString, nil
 	}
 
@@ -161,9 +164,29 @@ func (s String) isTxtVariable() (String, error) {
 	return EmptyString, nil
 }
 
-func (s String) EncodedInstance() string {
+func (s String) dotted() string {
+	return s.String() + "."
+}
+
+func (s String) ServiceAddr(d Domain) HostString {
+	return String(s.dotted() + d.dotted())
+}
+
+func (s String) InstanceAddr(sr Service, d Domain) HostString {
+	return String(s.encodedInstance().dotted() + sr.dotted() + d.dotted())
+}
+
+func (s String) HostnameAddr(d Domain) HostString {
+	return String(s.dotted() + d.dotted())
+}
+
+func (s String) DiscoveryAddr() HostString {
+	return String(String("_services._dns-sd._udp").dotted() + s.dotted())
+}
+
+func (s String) encodedInstance() String {
 	// RFC 6763, 4.3: Must be escaped, except leading _
-	return strings.Replace(regexp.QuoteMeta(s.String()), "_", `\_`, -1)
+	return String(strings.Replace(regexp.QuoteMeta(s.String()), "_", `\_`, -1))
 }
 
 func (s String) Types() []string {

+ 73 - 63
server/server.go

@@ -5,32 +5,36 @@ import (
 	"os"
 	"os/signal"
 	"strings"
-	"sync"
-	"time"
 
 	"git.giaever.org/joachimmg/go-log.git/log"
 	"git.giaever.org/joachimmg/m-dns/config"
+	"git.giaever.org/joachimmg/m-dns/connection"
 	"git.giaever.org/joachimmg/m-dns/errors"
+	"git.giaever.org/joachimmg/m-dns/host"
 	"git.giaever.org/joachimmg/m-dns/zone"
 	"github.com/miekg/dns"
 )
 
-type MDnsServer struct {
-	sync.Mutex
+type Server interface {
+	Daemon()
+	Close() error
+}
+
+type server struct {
 	zone.Zone
 	*net.Interface
 
-	ipv4 *net.UDPConn
-	ipv6 *net.UDPConn
+	ipv4 connection.UDP
+	ipv6 connection.UDP
 
 	running bool
 	runCh   chan struct{}
 }
 
-func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
-	m := new(MDnsServer)
+func New(z zone.Zone, iface *net.Interface) (Server, error) {
+	s := new(server)
 
-	if m == nil {
+	if s == nil {
 		log.Traceln(errors.Server, errors.OutOfMemory)
 		return nil, errors.OutOfMemory
 	}
@@ -40,84 +44,80 @@ func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
 		return nil, errors.ServerIsNil
 	}
 
-	m.Zone = z
-	m.Interface = iface
-
-	var err error
+	s.Zone = z
+	s.Interface = iface
+	s.ipv4 = connection.New(4)
+	s.ipv6 = connection.New(6)
 
-	if m.ipv4, err = net.ListenMulticastUDP("udp4", m.Interface, config.MdnsIPv4Addr); err != nil {
-		log.Infoln(errors.Server, config.MdnsIPv4, err)
+	if err := s.ipv4.ListenMulticast(s.Interface, config.MdnsIPv4Addr); err != nil {
+		log.Traceln(errors.Server, config.MdnsIPv4, err)
 	}
 
-	if m.ipv6, err = net.ListenMulticastUDP("udp6", m.Interface, config.MdnsIPv6Addr); err != nil {
-		log.Infoln(errors.Server, config.MdnsIPv6, err)
+	if err := s.ipv6.ListenMulticast(s.Interface, config.MdnsIPv6Addr); err != nil {
+		log.Traceln(errors.Server, config.MdnsIPv6, err)
 	}
 
-	if m.ipv4 == nil && m.ipv6 == nil {
+	if !s.ipv4.Listening() && s.ipv6.Listening() {
 		log.Traceln(errors.Server, errors.ServerNoListenersStarted)
 		return nil, errors.ServerNoListenersStarted
 	}
 
-	m.runCh = make(chan struct{})
-	m.running = true
+	s.runCh = make(chan struct{})
+	s.running = true
 
-	go m.recv(m.ipv4)
-	go m.recv(m.ipv6)
+	go s.recv(s.ipv4)
+	go s.recv(s.ipv6)
 
-	return m, nil
+	return s, nil
 }
 
-func (m *MDnsServer) shutdownListener() {
+func (s *server) shutdownListener() {
 	log.Traceln("Shutdown listener set on ctrl+x")
 	c := make(chan os.Signal, 1)
 	signal.Notify(c, os.Interrupt)
 	go func() {
 		for range c {
-			m.Shutdown()
+			s.Close()
 		}
 	}()
 }
 
-func (m *MDnsServer) Shutdown() {
-	log.Traceln("Shutting down MDNS-server")
-	m.Lock()
-	defer m.Unlock()
-
-	if !m.running {
-		return
+func (s *server) Close() error {
+	log.Traceln(errors.Server, "Closing")
+	if !s.running {
+		return nil
 	}
 
-	if m.ipv4 != nil {
-		m.ipv4.Close()
+	if err := s.ipv4.Close(); err != nil {
+		return err
 	}
 
-	if m.ipv6 != nil {
-		m.ipv6.Close()
+	if err := s.ipv6.Close(); err != nil {
+		return err
 	}
 
-	m.running = false
-	close(m.runCh)
+	s.running = false
+	close(s.runCh)
+
+	return nil
 }
 
-func (m *MDnsServer) Daemon() {
+func (s *server) Daemon() {
 	log.Traceln("Daemon running.")
-	go m.shutdownListener()
-	<-m.runCh
+	go s.shutdownListener()
+	<-s.runCh
 	log.Traceln("Daemon ending.")
 }
 
-func (m *MDnsServer) recv(i *net.UDPConn) {
-	if i == nil {
+func (s *server) recv(c connection.UDP) {
+	if c == nil {
 		return
 	}
 
 	buf := make([]byte, config.BufSize)
 
-	for m.running {
-		m.Lock()
-		i.SetReadDeadline(time.Now().Add(config.BufReadDeadline * time.Second))
-		_, addr, err := i.ReadFromUDP(buf)
-		m.Unlock()
+	for s.running {
+		_, addr, err := c.Read(buf)
 
 		if err != nil {
 			if !strings.Contains(err.Error(), "i/o timeout") {
@@ -126,12 +126,12 @@ func (m *MDnsServer) recv(i *net.UDPConn) {
 			continue
 		}
 
-		go m.handlePacket(buf, addr)
+		go s.handlePacket(buf, addr)
 
 	}
 }
 
-func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
+func (s *server) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
 
 	if msg == nil {
 		return 0, nil
@@ -143,37 +143,47 @@ func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
 		return 0, err
 	}
 
-	if addr.IP.To4() != nil {
-		return m.ipv4.WriteToUDP(buf, addr)
+	_, t := host.IP(addr.IP).Type()
+
+	if !s.running {
+		return 0, nil
 	}
 
-	return m.ipv6.WriteToUDP(buf, addr)
+	switch t {
+	case host.IPv4:
+		return s.ipv4.Write(buf, addr)
+	case host.IPv6:
+		return s.ipv6.Write(buf, addr)
+	default:
+		log.Traceln(errors.Server, addr, t, errors.ServerUnknownConnectionAddr)
+		return 0, errors.ServerUnknownConnectionAddr
+	}
 }
 
-func (m *MDnsServer) handlePacket(p []byte, addr *net.UDPAddr) {
+func (s *server) handlePacket(p []byte, addr *net.UDPAddr) {
 	msg := new(dns.Msg)
 
 	if err := msg.Unpack(p); err != nil {
 		log.Warningln(errors.Server, addr, err)
 	}
 
-	umsg, mmsg, err := m.handleMsg(msg)
+	umsg, mmsg, err := s.handleMsg(msg)
 
 	if err != nil {
 		log.Warningln(errors.Server, addr, err)
 		return
 	}
 
-	if n, err := m.send(umsg, addr); err != nil {
+	if n, err := s.send(umsg, addr); err != nil {
 		log.Warningln(errors.Server, "Wrote", n, err)
 	}
 
-	if n, err := m.send(mmsg, addr); err != nil {
+	if n, err := s.send(mmsg, addr); err != nil {
 		log.Warningln(errors.Server, "Wrote", n, err)
 	}
 }
 
-func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
+func (s *server) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
 	if msg.Opcode != dns.OpcodeQuery {
 		log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
 		return nil, nil, errors.ServerReceivedNonQueryOpcode
@@ -192,16 +202,16 @@ func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
 	var uAnswer, mAnswer []dns.RR
 
 	for _, q := range msg.Question {
-		uRecords, mRecords := m.handleQuestion(q)
+		uRecords, mRecords := s.handleQuestion(q)
 		uAnswer = append(uAnswer, uRecords...)
 		mAnswer = append(mAnswer, mRecords...)
 	}
 
-	return m.handleResponse(msg, true, uAnswer), m.handleResponse(msg, false, mAnswer), nil
+	return s.handleResponse(msg, true, uAnswer), s.handleResponse(msg, false, mAnswer), nil
 }
 
-func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
-	r := m.Records(question)
+func (s *server) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
+	r := s.Records(question)
 
 	if len(r) == 0 {
 		return nil, nil
@@ -218,7 +228,7 @@ func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR)
 	return nil, r
 }
 
-func (m *MDnsServer) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
+func (s *server) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
 	id := uint16(0)
 	if uni {
 		id = msg.Id

+ 1 - 1
zone/zone.go

@@ -39,7 +39,7 @@ func (r *Resource) qType(q dns.Question) string {
 	}
 }
 
-func New(h host.Host) (*Resource, error) {
+func New(h host.Host) (Zone, error) {
 
 	r := new(Resource)