Browse Source

Re-structure

Joachim M. Giæver 7 years ago
parent
commit
8763c60767

+ 17 - 0
client/client.go

@@ -0,0 +1,17 @@
+package client
+
+import (
+	"sync"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/zone"
+)
+
+type MDnsClient struct {
+	sync.Mutex
+	ipv4 map[string]*net.UDPConn
+	ipv6 map[string]*net.UDPConn
+
+	running bool
+	runCh   chan struct{}
+}

+ 8 - 0
client/query/query.go

@@ -0,0 +1,8 @@
+package query
+
+import (
+	"git.giaever.org/joachimmg/m-dns/host"
+
+type Query struct {
+	Service	host.HostString
+}

+ 3 - 0
config/config.go

@@ -12,6 +12,9 @@ const (
 	ForceUnicast = false
 
 	DefaultTTL = 120
+
+	BufSize         = 65536
+	BufReadDeadline = 2
 )
 
 var (

BIN
docs/RFC 6762 - Multicast DNS.pdf


BIN
docs/RFC 6763 - DNS-Based Service Discovery.pdf


+ 38 - 9
errors/errors.go

@@ -4,14 +4,43 @@ import (
 	e "errors"
 )
 
+type Prefix string
+
+const (
+	Zone       Prefix = "Zone"
+	Host       Prefix = "Host"
+	HostString Prefix = "HostString"
+	HostIP     Prefix = "HostIP"
+	HostPort   Prefix = "HostPort"
+	Server     Prefix = "Server"
+)
+
 var (
-	ZoneInstanseIsEmpty = e.New("Zone: Instance name is missing.")
-
-	HostServiceIsEmpty            = e.New("Host: Service name is missing.")
-	HostHostnameNotFQDM           = e.New("Host: Hostname is not a fully-qualified domain name.")
-	HostHostnameCouldNotDetermine = e.New("Host: Could not determine hostname")
-	HostIPCouldNotDetermine       = e.New("Host: Could not determine IP address(es) for hostname")
-	HostIPAddressIsInvalid        = e.New("Host: Invalid IP-address in IP-list")
-	HostDomainNotFQDM             = e.New("Host: Domain is not a fully-qualified domain name.")
-	HostPortInvalid               = e.New("Host: Port is missing or invalid.")
+	OutOfMemory = e.New("Out of memory")
+
+	HostIsNil   = e.New("Host is nil.")
+	ServerIsNil = e.New("Server is nil.")
+
+	HostStringIsEmpty           = e.New("String is empty.")
+	HostStringIsInvalid         = e.New("String is invalid.")
+	HostStringIsInvalidInstance = e.New("String is not an instance, e.g <Readable description>.")
+	HostStringIsInvalidService  = e.New("String is not a service, e.g <_service._tcp> or <_service._udp>.")
+	HostStringIsInvalidDomain   = e.New("String is no a valid domain, e.g <local> or <my.domain>.")
+	HostStringIsInvalidHostname = e.New("Hostname is invalid.")
+
+	HostIPIsInvalid = e.New("IP address(es) is not valid for hostname.")
+
+	HostPortIsInvalid = e.New("Port number is invalid (p < 0 || p > 65535).")
+
+	HostTXTExceedsLimit = e.New("TXT record exceed intended size -- 200 bytes or less.")
+
+	ServerNoListenersStarted     = e.New("No multicast listeners started.")
+	ServerReceivedNonQueryOpcode = e.New("Received non-query Opcode")
+	ServerReceivedNonZeroRcode   = e.New("Received non-zero Rcode")
+	ServerReceivedTruncatedSet   = e.New("Reveived trucated bit set")
+	ServerNoResponseForQuestion  = e.New("No response for question.")
 )
+
+func (p Prefix) String() string {
+	return string(p) + ":"
+}

+ 68 - 0
example/example.go

@@ -0,0 +1,68 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"net"
+	"net/http"
+	"os"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/host"
+	"git.giaever.org/joachimmg/m-dns/server"
+	"git.giaever.org/joachimmg/m-dns/zone"
+)
+
+func init() {
+	flag.Parse()
+}
+
+func main() {
+	hostname, err := os.Hostname()
+
+	if err != nil {
+		log.Errorln(err)
+	}
+
+	txt := []string{
+		"login=true",
+		"admin=/admin",
+		"autosign=",
+	}
+
+	host, err := host.New(
+		"This is _a_ .dotted. instance",
+		"_myservice._tcp",
+		"local",
+		hostname,
+		[]net.IP{net.ParseIP("192.168.1.128")},
+		8080,
+		txt,
+	)
+
+	if err != nil {
+		log.Errorln(err)
+	}
+
+	zone, err := zone.New(host)
+
+	if err != nil {
+		log.Errorln(err)
+	}
+
+	mdnss, err := server.New(zone, nil)
+
+	if err != nil {
+		log.Errorln(err)
+	}
+	log.Traceln(mdnss)
+
+	//mdnss.Daemon()
+
+	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprintf(w, "Hello, I'm just here hangig around.")
+	})
+
+	log.Panicln(http.ListenAndServe(":8080", nil))
+
+}

+ 125 - 180
host/host.go

@@ -3,251 +3,196 @@ package host
 import (
 	"fmt"
 	"net"
-	"os"
-	"strconv"
 	"strings"
 
 	"git.giaever.org/joachimmg/go-log.git/log"
 	"git.giaever.org/joachimmg/m-dns/errors"
 )
 
-type HostString string
-type HostIP net.IP
-type HostIPs []HostIP
-type HostIPvType int
-type HostPort int
-
-const (
-	IPv4 HostIPvType = iota << 1
-	IPv6 HostIPvType = iota << 1
-	NoIP HostIPvType = iota << 1
-)
-
-type Host struct {
-	service  HostString
-	domain   HostString
-	hostname HostString
-	ips      HostIPs
+type Host interface {
+	GetInstance() Instance
+	GetService() Service
+	GetDomain() Domain
+	GetHostname() Hostname
+	GetIPs() []HostIP
+	GetPort() HostPort
+	GetTXTs() []string
+	GetServiceAddr() HostString
+	GetInstanceAddr() HostString
+	GetHostnameAddr() HostString
+	GetDiscoveryAddr() HostString
+	String() string
+}
+
+type host struct {
+	instance Instance
+	service  Service
+	domain   Domain
+	hostname Hostname
+	ip       []HostIP
 	port     HostPort
+	txt      []TXT
 }
 
-func New() *Host {
-	return new(Host)
-}
+func New(instance, service, domain, hostname string, ip []net.IP, port int, txt []string) (*host, error) {
 
-func (h *Host) SetService(s string) error {
-	if err := HostString(s).validService(); err != nil {
-		return err
-	}
-	h.service = HostString(s)
-	return nil
-}
+	h := new(host)
 
-func (h *Host) SetDomain(d string) error {
-	if len(d) == 0 {
-		d = "local."
+	if h == nil {
+		log.Traceln(errors.Host, errors.OutOfMemory)
+		return nil, errors.OutOfMemory
 	}
 
-	if err := HostString(d).validDomain(); err != nil {
-		return err
+	var err error
+	if h.instance, err = String(instance).isInstanceVariable(); err != nil {
+		return nil, err
 	}
 
-	h.domain = HostString(d)
-	return nil
-}
+	if h.service, err = String(service).isServiceVariable(); err != nil {
+		return nil, err
+	}
 
-func (h *Host) SetHostname(hn string) error {
-	if len(hn) == 0 {
-		var err error
-		hn, err = os.Hostname()
-		if err != nil {
-			log.Warningln(h, hn, err)
-			return errors.HostHostnameCouldNotDetermine
-		}
-		hn = fmt.Sprintf("%s.", hn)
+	if h.domain, err = String(domain).isDomainVariable(); err != nil {
+		return nil, err
 	}
 
-	if err := HostString(hn).validHostname(); err != nil {
-		return err
+	if h.hostname, err = String(hostname).isHostnameVariable(); err != nil {
+		return nil, err
 	}
 
-	h.hostname = HostString(hn)
-	return nil
-}
+	for _, _ip := range ip {
+		if hip, err := IP(_ip).validForHostname(h.hostname); err == nil {
+			h.ip = append(h.ip, hip)
+		}
+	}
 
-func (h *Host) SetIPs(i []net.IP) error {
-	if len(i) == 0 {
-		var err error
-		i, err = net.LookupIP(fmt.Sprintf("%s%s", h.GetHostname(), h.GetDomain()))
+	if len(h.ip) == 0 {
+		addrs, err := net.InterfaceAddrs()
 
 		if err != nil {
-			log.Warningln(h, err)
-			return errors.HostIPCouldNotDetermine
+			log.Traceln(errors.Host, err)
+			return nil, err
 		}
 
+		for _, addr := range addrs {
+			ip := strings.SplitN(addr.String(), "/", 2)
+			if len(ip) >= 1 {
+				if hip, err := IP(net.ParseIP(ip[0])).validForHostname(h.GetHostname()); err == nil {
+					h.ip = append(h.ip, hip)
+				}
+			}
+		}
 	}
 
-	var ips HostIPs
-
-	for _, ip := range i {
-		ips = append(ips, HostIP(ip))
+	if err := Port(port).isValid(); err != nil {
+		return nil, err
 	}
 
-	if err := ips.validIPs(); err != nil {
-		log.Warningln(err)
-		return err
+	h.port = Port(port)
+
+	for _, t := range txt {
+		ht, err := String(t).isTxtVariable()
+		if err != nil {
+			return nil, err
+		} else if len(ht) != 0 {
+			h.txt = append(h.txt, ht)
+		}
 	}
 
-	h.ips = ips
-	return nil
+	return h, nil
 }
 
-func (h *Host) SetPort(p int) error {
-	if err := HostPort(p).validPort(); err != nil {
-		return err
-	}
-	h.port = HostPort(p)
-	return nil
+func (h *host) GetInstance() Instance {
+	return h.instance
 }
 
-func (h *Host) GetService() HostString {
-	if err := h.service.validService(); err != nil {
-		log.Panicln(err)
-	}
+func (h *host) GetService() Service {
 	return h.service
 }
 
-func (h *Host) GetDomain() HostString {
-	if err := h.domain.validDomain(); err != nil {
-		log.Panicln(err)
-	}
+func (h *host) GetDomain() Domain {
 	return h.domain
 }
 
-func (h *Host) GetHostname() HostString {
-	if err := h.hostname.validHostname(); err != nil {
-		log.Panicln(err)
-	}
+func (h *host) GetHostname() Hostname {
 	return h.hostname
 }
 
-func (h *Host) GetIPs() HostIPs {
-	if err := h.ips.validIPs(); err != nil {
-		log.Panicln(err)
-	}
-	return h.ips
-}
-
-func (h *Host) GetPort() HostPort {
-	if err := h.port.validPort(); err != nil {
-		log.Panicln(err)
-	}
-	return h.port
-}
-
-func (h *Host) String() string {
-	return fmt.Sprintf("Service: %s\nDomain: %s\nHostname: %s\nIPs: %s\nPort: %s\n",
-		h.GetService(),
-		h.GetDomain(),
-		h.GetHostname(),
-		strings.Join(h.GetIPs().String(), ", "),
-		h.GetPort(),
-	)
+func (h *host) GetIPs() []HostIP {
+	return h.ip
 }
 
-func (h HostString) ValidFQDM() bool {
-	if len(h) == 0 {
-		log.Warningln("Cannot be blank", h)
-		return false
-	}
-
-	if h[len(h)-1] != '.' {
-		log.Warningf("Missing trailing '.' in <%s>", h)
+func (h *host) HasIPs() bool {
+	if len(h.GetIPs()) == 0 {
 		return false
 	}
-
-	return true
-}
-
-func (h HostString) validService() error {
-	if len(h) == 0 {
-		log.Warningln("Service is empty", h)
-		return errors.HostServiceIsEmpty
-	}
-	return nil
-}
-
-func (h HostString) validHostname() error {
-	if !h.ValidFQDM() {
-		return errors.HostHostnameNotFQDM
-	}
-	return nil
-}
-
-func (h HostString) validDomain() error {
-	if !h.ValidFQDM() {
-		return errors.HostDomainNotFQDM
+	for _, ip := range h.GetIPs() {
+		if _, t := ip.Type(); t == IPv4 || t == IPv6 {
+			return true
+		}
 	}
-	return nil
-}
-
-func (h HostString) TrimDot() string {
-	return strings.Trim(h.String(), ".")
+	return false
 }
 
-func (h HostString) String() string {
-	return string(h)
-}
-
-func (h HostIPs) validIPs() error {
-	for _, ip := range h {
-		if t := ip.Type(); t == NoIP {
-			return errors.HostIPAddressIsInvalid
-		}
-	}
-	return nil
+func (h *host) GetPort() HostPort {
+	return h.port
 }
 
-func (h HostIPs) String() []string {
+func (h *host) GetTXTs() []string {
 	s := []string{}
-	for _, ip := range h {
-		s = append(s, ip.String())
+	for _, t := range h.txt {
+		s = append(s, t.String())
 	}
 	return s
 }
 
-func (h HostIP) Type() HostIPvType {
-	if h.AsIP().To4() != nil {
-		return IPv4
-	}
-
-	if h.AsIP().To16() != nil {
-		return IPv6
-	}
-
-	return NoIP
+func (h *host) GetServiceAddr() HostString {
+	return String(fmt.Sprintf("%s.%s.", h.GetService(), h.GetDomain()))
 }
 
-func (h HostIP) AsIP() net.IP {
-	return net.IP(h)
+func (h *host) GetInstanceAddr() HostString {
+	return String(fmt.Sprintf("%s.%s", h.GetInstance().EncodedInstance(), h.GetServiceAddr()))
 }
 
-func (h HostIP) String() string {
-	return h.AsIP().String()
+func (h *host) GetHostnameAddr() HostString {
+	return String(fmt.Sprintf("%s.%s.", h.GetHostname(), h.GetDomain()))
 }
 
-func (h HostPort) validPort() error {
-	if h <= 0 {
-		log.Warningf("Port <%d> is invalid.", h)
-		return errors.HostPortInvalid
-	}
-	return nil
+func (h *host) GetDiscoveryAddr() HostString {
+	return String(fmt.Sprintf("_services._dns-sd._udp.%s.", h.GetDomain()))
 }
 
-func (h HostPort) Uint16() uint16 {
-	return uint16(h)
-}
+func (h *host) String() string {
+	ips := []string{}
+	for _, ip := range h.GetIPs() {
+		ips = append(ips, ip.String())
+	}
 
-func (h HostPort) String() string {
-	return strconv.Itoa(int(h))
+	return fmt.Sprintf(`
+Host:	%s
+	+ Instance:	%s
+	+ Service:	%s
+	+ Domain:	%s
+	+ IP(s):
+		* %s
+	+ Port:		%d
+	+ Txt(s):	
+		* %s
+	Addr:
+		+ Instance:	%s
+		+ Service:	%s
+		+ Discovery:	%s
+	Service-protocol:	%s`,
+		h.GetHostname(),
+		h.GetInstance(),
+		h.GetService(),
+		h.GetDomain(),
+		strings.Join(ips, "\n\t\t* "),
+		h.port,
+		strings.Join(h.GetTXTs(), "\n\t\t* "),
+		h.GetInstanceAddr(),
+		h.GetServiceAddr(),
+		h.GetDiscoveryAddr(),
+		h.GetService().RootType()[1:],
+	)
 }

+ 0 - 53
host/host_test.go

@@ -1,53 +0,0 @@
-package host
-
-import (
-	"net"
-	"os"
-	"testing"
-)
-
-func TestHost(t *testing.T) {
-	host := New()
-
-	var ips []net.IP
-
-	hostname, _ := os.Hostname()
-
-	if err := host.SetService(""); err == nil {
-		t.Fatal("Not expecting empty service name to be valid")
-	}
-
-	if err := host.SetService("service"); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := host.SetHostname(hostname); err == nil {
-		t.Fatal("Hostname should contain trailing .")
-	}
-
-	if err := host.SetHostname(hostname + "."); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := host.SetDomain("local"); err == nil {
-		t.Fatal("Domain should contain trailing .")
-	}
-
-	if err := host.SetDomain("local."); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := host.SetIPs(ips); err != nil {
-		t.Fatal(err)
-	}
-
-	if err := host.SetPort(0); err == nil {
-		t.Fatal("Port should not be 0 or less")
-	}
-
-	if err := host.SetPort(4545); err != nil {
-		t.Fatal(err)
-	}
-
-	t.Log(host.String())
-}

+ 111 - 0
host/ip.go

@@ -0,0 +1,111 @@
+package host
+
+import (
+	"net"
+	"os"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/errors"
+)
+
+type IPvType int
+
+const (
+	IPv4 IPvType = iota << 1
+	IPv6 IPvType = iota << 1
+	NoIP IPvType = iota << 1
+)
+
+type HostIP interface {
+	String() string
+	Type() (net.IP, IPvType)
+	AsIP() net.IP
+}
+
+type IP net.IP
+
+func (i IP) String() string {
+	return i.AsIP().String()
+}
+
+func (i IP) AsIP() net.IP {
+	return net.IP(i)
+}
+
+func (i IP) Type() (net.IP, IPvType) {
+	if ip := i.AsIP().To4(); ip != nil {
+		switch ip.String() {
+		case "127.0.0.1",
+			net.IPv4zero.String(),
+			net.IPv4bcast.String(),
+			net.IPv4allsys.String(),
+			net.IPv4allrouter.String():
+			return nil, NoIP
+		default:
+			return ip, IPv4
+		}
+	}
+
+	if ip := i.AsIP().To16(); ip != nil {
+		switch ip.String() {
+		case net.IPv6zero.String(),
+			net.IPv6loopback.String(),
+			net.IPv6unspecified.String(),
+			net.IPv6linklocalallnodes.String(),
+			net.IPv6linklocalallrouters.String(),
+			net.IPv6interfacelocalallnodes.String():
+			return nil, NoIP
+		default:
+			return ip, IPv6
+		}
+	}
+
+	return nil, NoIP
+}
+
+func (i IP) validForHostname(hn HostString) (IP, error) {
+	log.Traceln(errors.HostIP, i, hn)
+
+	if _, t := i.Type(); t == NoIP {
+		log.Traceln(errors.HostIP, errors.HostIPIsInvalid)
+		return nil, errors.HostIPIsInvalid
+	}
+
+	f := false
+	if lhn, _ := os.Hostname(); lhn == hn.String() {
+		addrs, err := net.InterfaceAddrs()
+
+		if err != nil {
+			log.Traceln(errors.HostIP, err)
+			return nil, err
+		}
+
+		for _, addr := range addrs {
+			if addr.(*net.IPNet).Contains(i.AsIP()) {
+				f = true
+				break
+			}
+		}
+
+	} else {
+		addrs, err := net.LookupIP(hn.String())
+
+		if err != nil {
+			log.Traceln(errors.HostIP, err)
+			return nil, err
+		}
+
+		for _, addr := range addrs {
+			if addr.String() == i.String() {
+				f = true
+			}
+		}
+	}
+
+	if f == false {
+		log.Traceln(errors.HostIP, errors.HostIPIsInvalid)
+		return nil, errors.HostIPIsInvalid
+	}
+
+	return i, nil
+}

+ 30 - 0
host/port.go

@@ -0,0 +1,30 @@
+package host
+
+import (
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/errors"
+)
+
+type HostPort interface {
+	Int() int
+	Uint16() uint16
+}
+
+type Port int
+
+func (p Port) Int() int {
+	return int(p)
+}
+
+func (p Port) Uint16() uint16 {
+	return uint16(p)
+}
+
+func (p Port) isValid() error {
+	log.Traceln(errors.HostPort, p)
+	if int(p) < 0 || int(p) > 65535 {
+		return errors.HostPortIsInvalid
+	}
+
+	return nil
+}

+ 176 - 0
host/string.go

@@ -0,0 +1,176 @@
+package host
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"regexp"
+	"strings"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/errors"
+)
+
+type HostString interface {
+	String() string
+}
+
+type Instance interface {
+	HostString
+	EncodedInstance() string
+}
+
+type Service interface {
+	HostString
+	Types() []string
+	RootType() string
+}
+
+type Domain interface {
+	HostString
+}
+
+type Hostname interface {
+	HostString
+}
+
+type TXT interface {
+	HostString
+}
+
+type String string
+
+const (
+	EmptyString String = ""
+)
+
+func (s String) String() string {
+	return string(s)
+}
+
+func (s String) isEmpty() bool {
+	return len(s) == 0
+}
+
+func (s String) isValid() error {
+	log.Traceln(errors.HostString, s)
+	if s.isEmpty() {
+		log.Traceln(errors.HostString, s, errors.HostStringIsEmpty)
+		return errors.HostStringIsEmpty
+	}
+
+	if s[len(s)-1] == '.' {
+		log.Traceln(errors.HostString, s, errors.HostStringIsInvalid)
+		return errors.HostStringIsInvalid
+	}
+
+	return nil
+}
+
+func (s String) isInstanceVariable() (String, error) {
+	if s.isEmpty() {
+		log.Traceln(errors.HostString, errors.HostStringIsInvalidInstance)
+		return EmptyString, errors.HostStringIsInvalidInstance
+	}
+
+	re := regexp.MustCompile(`^[\x20-\x7E]+$`)
+
+	if rs := re.FindAllStringSubmatch(s.String(), 1); len(rs) != 0 {
+		return s, nil
+	}
+
+	log.Traceln(errors.HostString, errors.HostStringIsInvalidInstance)
+	return EmptyString, errors.HostStringIsInvalidInstance
+}
+
+func (s String) isServiceVariable() (String, error) {
+	if err := s.isValid(); err != nil {
+		log.Traceln(errors.HostString, errors.HostStringIsInvalidService, err)
+		return EmptyString, errors.HostStringIsInvalidService
+	}
+
+	// RFC 6763: Service pair _<name>._<type> (including (_sub.+)._name._type)
+	re := regexp.MustCompile(`^((?:(\_[a-z\-]+)\.)+)(?:(\_+(?:tcp|udp))+)$`)
+
+	if rs := re.FindAllStringSubmatch(s.String(), 1); len(rs) != 0 {
+		switch rs[0][3] {
+		case "_tcp", "_udp":
+			return s, nil
+		}
+	}
+
+	log.Traceln(errors.HostString, errors.HostStringIsInvalidService)
+	return EmptyString, errors.HostStringIsInvalidService
+}
+
+func (s String) isDomainVariable() (String, error) {
+	if err := s.isValid(); err != nil {
+		log.Traceln(errors.HostString, errors.HostStringIsInvalidDomain)
+		return EmptyString, errors.HostStringIsInvalidDomain
+	}
+
+	if s == "local" {
+		return s, nil
+	}
+
+	if _, err := net.LookupHost(s.String()); err != nil {
+		log.Traceln(errors.HostString, err)
+		return EmptyString, errors.HostStringIsInvalidDomain
+	}
+
+	return s, nil
+}
+
+func (s String) isHostnameVariable() (String, error) {
+	if err := s.isValid(); err != nil {
+		log.Traceln(errors.HostString, errors.HostStringIsInvalidHostname, err)
+		return EmptyString, errors.HostStringIsInvalidHostname
+	}
+
+	if hostname, err := os.Hostname(); err == nil {
+		if s.String() == hostname {
+			return s, nil
+		}
+	}
+
+	if _, err := net.LookupHost(s.String()); err != nil {
+		log.Traceln(errors.HostString, err)
+		return EmptyString, errors.HostStringIsInvalidHostname
+	}
+
+	return s, nil
+}
+
+func (s String) isTxtVariable() (String, error) {
+	log.Traceln(errors.HostString, s)
+	if s.isEmpty() || s[:1][0] == '=' {
+		return EmptyString, nil
+	}
+
+	if len(s) > 200 {
+		return EmptyString, errors.HostTXTExceedsLimit
+	}
+
+	// RFC 6763: Spaces in key is significant, can include any character(incl. '=') in value
+	re := regexp.MustCompile(`^([\x20-\x3C\x3E-\x7E]+)(?:(\=+)?([\x20-\x7E]+)?)$`)
+
+	if rs := re.FindAllStringSubmatch(s.String(), 1); len(rs) != 0 {
+		return String(fmt.Sprintf("%s%s%s", strings.Trim(rs[0][1], " "), rs[0][2], rs[0][3])), nil
+	}
+
+	return EmptyString, nil
+}
+
+func (s String) EncodedInstance() string {
+	// RFC 6763, 4.3: Must be escaped, except leading _
+	return strings.Replace(regexp.QuoteMeta(s.String()), "_", `\_`, -1)
+}
+
+func (s String) Types() []string {
+	return strings.Split(s.String(), ".")
+}
+
+func (s String) RootType() string {
+	t := s.Types()
+	return t[len(t)-1]
+}

+ 241 - 0
server/server.go

@@ -0,0 +1,241 @@
+package server
+
+import (
+	"net"
+	"os"
+	"os/signal"
+	"strings"
+	"sync"
+	"time"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/config"
+	"git.giaever.org/joachimmg/m-dns/errors"
+	"git.giaever.org/joachimmg/m-dns/zone"
+	"github.com/miekg/dns"
+)
+
+type MDnsServer struct {
+	sync.Mutex
+	zone.Zone
+	*net.Interface
+
+	ipv4 *net.UDPConn
+	ipv6 *net.UDPConn
+
+	running bool
+	runCh   chan struct{}
+}
+
+func New(z zone.Zone, iface *net.Interface) (*MDnsServer, error) {
+	m := new(MDnsServer)
+
+	if m == nil {
+		log.Traceln(errors.Server, errors.OutOfMemory)
+		return nil, errors.OutOfMemory
+	}
+
+	if z == nil {
+		log.Traceln(errors.Server, errors.ServerIsNil)
+		return nil, errors.ServerIsNil
+	}
+
+	m.Zone = z
+	m.Interface = iface
+
+	var err error
+
+	if m.ipv4, err = net.ListenMulticastUDP("udp4", m.Interface, config.MdnsIPv4Addr); err != nil {
+		log.Infoln(errors.Server, config.MdnsIPv4, err)
+	}
+
+	if m.ipv6, err = net.ListenMulticastUDP("udp6", m.Interface, config.MdnsIPv6Addr); err != nil {
+		log.Infoln(errors.Server, config.MdnsIPv6, err)
+	}
+
+	if m.ipv4 == nil && m.ipv6 == nil {
+		log.Traceln(errors.Server, errors.ServerNoListenersStarted)
+		return nil, errors.ServerNoListenersStarted
+	}
+
+	m.runCh = make(chan struct{})
+	m.running = true
+
+	go m.recv(m.ipv4)
+	go m.recv(m.ipv6)
+
+	return m, nil
+}
+
+func (m *MDnsServer) shutdownListener() {
+	log.Traceln("Shutdown listener set on ctrl+x")
+	c := make(chan os.Signal, 1)
+	signal.Notify(c, os.Interrupt)
+	go func() {
+		for range c {
+			m.Shutdown()
+		}
+	}()
+}
+
+func (m *MDnsServer) Shutdown() {
+	log.Traceln("Shutting down MDNS-server")
+	m.Lock()
+	defer m.Unlock()
+
+	if !m.running {
+		return
+	}
+
+	if m.ipv4 != nil {
+		m.ipv4.Close()
+	}
+
+	if m.ipv6 != nil {
+		m.ipv6.Close()
+	}
+
+	m.running = false
+	close(m.runCh)
+}
+
+func (m *MDnsServer) Daemon() {
+	log.Traceln("Daemon running.")
+	go m.shutdownListener()
+	<-m.runCh
+	log.Traceln("Daemon ending.")
+}
+
+func (m *MDnsServer) recv(i *net.UDPConn) {
+	if i == nil {
+		return
+	}
+
+	buf := make([]byte, config.BufSize)
+
+	for m.running {
+		m.Lock()
+		i.SetReadDeadline(time.Now().Add(config.BufReadDeadline * time.Second))
+		_, addr, err := i.ReadFromUDP(buf)
+		m.Unlock()
+
+		if err != nil {
+			if !strings.Contains(err.Error(), "i/o timeout") {
+				log.Traceln(errors.Server, err)
+			}
+			continue
+		}
+
+		go m.handlePacket(buf, addr)
+
+	}
+}
+
+func (m *MDnsServer) send(msg *dns.Msg, addr *net.UDPAddr) (int, error) {
+
+	if msg == nil {
+		return 0, nil
+	}
+
+	buf, err := msg.Pack()
+
+	if err != nil {
+		return 0, err
+	}
+
+	if addr.IP.To4() != nil {
+		return m.ipv4.WriteToUDP(buf, addr)
+	}
+
+	return m.ipv6.WriteToUDP(buf, addr)
+}
+
+func (m *MDnsServer) handlePacket(p []byte, addr *net.UDPAddr) {
+	msg := new(dns.Msg)
+
+	if err := msg.Unpack(p); err != nil {
+		log.Warningln(errors.Server, addr, err)
+	}
+
+	umsg, mmsg, err := m.handleMsg(msg)
+
+	if err != nil {
+		log.Warningln(errors.Server, addr, err)
+		return
+	}
+
+	if n, err := m.send(umsg, addr); err != nil {
+		log.Warningln(errors.Server, "Wrote", n, err)
+	}
+
+	if n, err := m.send(mmsg, addr); err != nil {
+		log.Warningln(errors.Server, "Wrote", n, err)
+	}
+}
+
+func (m *MDnsServer) handleMsg(msg *dns.Msg) (*dns.Msg, *dns.Msg, error) {
+	if msg.Opcode != dns.OpcodeQuery {
+		log.Traceln(errors.Server, errors.ServerReceivedNonQueryOpcode)
+		return nil, nil, errors.ServerReceivedNonQueryOpcode
+	}
+
+	if msg.Rcode != 0 {
+		log.Traceln(errors.Server, errors.ServerReceivedNonZeroRcode)
+		return nil, nil, errors.ServerReceivedNonZeroRcode
+	}
+
+	if msg.Truncated {
+		log.Traceln(errors.Server, errors.ServerReceivedTruncatedSet)
+		return nil, nil, errors.ServerReceivedTruncatedSet
+	}
+
+	var uAnswer, mAnswer []dns.RR
+
+	for _, q := range msg.Question {
+		uRecords, mRecords := m.handleQuestion(q)
+		uAnswer = append(uAnswer, uRecords...)
+		mAnswer = append(mAnswer, mRecords...)
+	}
+
+	return m.handleResponse(msg, true, uAnswer), m.handleResponse(msg, false, mAnswer), nil
+}
+
+func (m *MDnsServer) handleQuestion(question dns.Question) ([]dns.RR, []dns.RR) {
+	r := m.Records(question)
+
+	if len(r) == 0 {
+		return nil, nil
+	}
+
+	for i, rec := range r {
+		log.Traceln(errors.Server, "Record", i, rec)
+	}
+
+	if question.Qclass&(1<<15) != 0 || config.ForceUnicast {
+		return r, nil
+	}
+
+	return nil, r
+}
+
+func (m *MDnsServer) handleResponse(msg *dns.Msg, uni bool, ans []dns.RR) *dns.Msg {
+	id := uint16(0)
+	if uni {
+		id = msg.Id
+	}
+
+	if len(ans) == 0 {
+		return nil
+	}
+
+	return &dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Id:            id,
+			Response:      true,
+			Opcode:        dns.OpcodeQuery,
+			Authoritative: true,
+		},
+		Compress: true,
+		Answer:   ans,
+	}
+}

+ 59 - 0
server/server_test.go

@@ -0,0 +1,59 @@
+package server
+
+import (
+	"net"
+	"os"
+	"testing"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/host"
+	"git.giaever.org/joachimmg/m-dns/zone"
+)
+
+func TestServerInit(t *testing.T) {
+	hostname, err := os.Hostname()
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	txt := []string{
+		"=ignore",
+		"key=value pair",
+		"key=",
+		"key   =   value=pair",
+		"k  e  y = v l a e u",
+		"--key",
+	}
+
+	host, err := host.New(
+		"This is info about \\my own service.",
+		"_http._tcp",
+		"local",
+		hostname,
+		[]net.IP{},
+		8001,
+		txt,
+	)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	zone, err := zone.New(host)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	log.Traceln(zone)
+
+	mdnss, err := New(zone, nil)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Log(mdnss)
+	mdnss.Daemon()
+	t.Fail()
+}

+ 172 - 171
zone/zone.go

@@ -2,8 +2,9 @@ package zone
 
 import (
 	"fmt"
-	"net"
+	"sync"
 
+	"git.giaever.org/joachimmg/go-log.git/log"
 	"git.giaever.org/joachimmg/m-dns/config"
 	"git.giaever.org/joachimmg/m-dns/errors"
 	"git.giaever.org/joachimmg/m-dns/host"
@@ -11,75 +12,154 @@ import (
 )
 
 type Zone interface {
-	Records(q dns.Question)
+	Records(q dns.Question) []dns.RR
 }
 
-type ZoneRecord struct {
-	Instance host.HostString
-	TXT      []string
+type Resource struct {
+	sync.Mutex
+	host.Host
+}
+
+func (r *Resource) qType(q dns.Question) string {
+	switch q.Qtype {
+	case dns.TypeANY:
+		return "TypeANY"
+	case dns.TypePTR:
+		return "TypePTR"
+	case dns.TypeSRV:
+		return "TypeSRV"
+	case dns.TypeA:
+		return "TypeA"
+	case dns.TypeAAAA:
+		return "TypeAAAA"
+	case dns.TypeTXT:
+		return "TypeTXT"
+	default:
+		return fmt.Sprintf("%d", q.Qtype)
+	}
+}
+
+func New(h host.Host) (*Resource, error) {
+
+	r := new(Resource)
+
+	if r == nil {
+		log.Traceln(errors.Zone, errors.OutOfMemory)
+		return nil, errors.OutOfMemory
+	}
 
-	sAddr string
-	iAddr string
-	eAddr string
+	if h == nil {
+		log.Traceln(errors.Zone, errors.HostIsNil)
+		return nil, errors.HostIsNil
+	}
+
+	r.Host = h
 
-	*host.Host
+	return r, nil
 }
 
-func New(instance, service, domain, hostname string, port int, ips []net.IP, txt []string) (*ZoneRecord, error) {
-	zr := new(ZoneRecord)
-	zr.Host = new(host.Host)
+func (r *Resource) hdr(n string, rt uint16) dns.RR_Header {
+	return dns.RR_Header{
+		Name:   n,
+		Rrtype: rt,
+		Class:  dns.ClassINET,
+		Ttl:    config.DefaultTTL,
+	}
+}
 
-	if len(instance) == 0 {
-		return nil, errors.ZoneInstanseIsEmpty
+func (r *Resource) ptrRecord(q dns.Question, p host.HostString) dns.RR {
+	r.Lock()
+	defer r.Unlock()
+	return &dns.PTR{
+		Hdr: r.hdr(q.Name, dns.TypePTR),
+		Ptr: p.String(),
 	}
+}
 
-	zr.Instance = host.HostString(instance)
+func (r *Resource) srvRecord(q dns.Question) dns.RR {
+	r.Lock()
+	defer r.Unlock()
+	log.Traceln("srvRecord")
+	return &dns.SRV{
+		Hdr:      r.hdr(q.Name, dns.TypeSRV),
+		Priority: 10,
+		Weight:   1,
+		Port:     r.GetPort().Uint16(),
+		Target:   r.GetHostnameAddr().String(),
+	}
+}
 
-	if err := zr.SetService(service); err != nil {
-		return nil, err
+func (r *Resource) txtRecord(q dns.Question) dns.RR {
+	r.Lock()
+	defer r.Unlock()
+	return &dns.TXT{
+		Hdr: r.hdr(q.Name, dns.TypeTXT),
+		Txt: r.GetTXTs(),
 	}
+}
 
-	if err := zr.SetDomain(domain); err != nil {
-		return nil, err
+func (r *Resource) records(rr ...dns.RR) []dns.RR {
+	rrn := make([]dns.RR, 0)
+	if rr == nil {
+		return rrn
 	}
+	return append(rrn, rr...)
+}
 
-	if err := zr.SetHostname(hostname); err != nil {
-		return nil, err
+func (r *Resource) aRecord(q dns.Question, ip host.HostIP) dns.RR {
+	r.Lock()
+	defer r.Unlock()
+	return &dns.A{
+		Hdr: r.hdr(r.GetHostnameAddr().String(), dns.TypeA),
+		A:   ip.AsIP(),
 	}
+}
 
-	if err := zr.SetPort(port); err != nil {
-		return nil, err
+func (r *Resource) aaaaRecord(q dns.Question, ip host.HostIP) dns.RR {
+	r.Lock()
+	defer r.Unlock()
+	return &dns.AAAA{
+		Hdr:  r.hdr(r.GetHostnameAddr().String(), dns.TypeAAAA),
+		AAAA: ip.AsIP(),
+	}
+}
+
+func (r *Resource) axRecords(q dns.Question, atype uint16) []dns.RR {
+	a := r.records(nil)[:0]
+	for _, ip := range r.GetIPs() {
+		switch atype {
+		case dns.TypeA:
+			if _, ipType := ip.Type(); ipType == host.IPv4 {
+				a = append(a, r.aRecord(q, ip))
+			}
+		case dns.TypeAAAA:
+			if _, ipType := ip.Type(); ipType == host.IPv6 {
+				a = append(a, r.aaaaRecord(q, ip))
+			}
+		}
 	}
 
-	zr.TXT = txt
+	if len(a) == 0 {
+		return nil
+	}
 
-	zr.sAddr = fmt.Sprintf("%s.%s.",
-		zr.GetService().TrimDot(),
-		zr.GetDomain().TrimDot(),
-	)
-	zr.iAddr = fmt.Sprintf("%s.%s.%s.",
-		zr.Instance.TrimDot,
-		zr.GetService().TrimDot(),
-		zr.GetDomain().TrimDot(),
-	)
-	zr.eAddr = fmt.Sprintf("_services._dns-ds._udp.%s.",
-		zr.GetDomain().TrimDot(),
-	)
-	return zr, nil
+	return a
 }
 
-func (zr *ZoneRecord) Records(q dns.Question) []dns.RR {
+// Records: Return DNS records based on the Question
+func (r *Resource) Records(q dns.Question) []dns.RR {
+	log.Traceln(errors.Zone, "Records", q.Name, r.qType(q))
 	switch q.Name {
-	case zr.sAddr:
-		return zr.sRecords(q)
-	case zr.iAddr:
-		return zr.iRecords(q)
-	case zr.eAddr:
-		return zr.sEnum(q)
-	case zr.GetHostname().String():
+	case r.GetInstanceAddr().String(): // RFC 6763, 13.3 "instance"._<service>.<domain>
+		return r.iRecords(q)
+	case r.GetServiceAddr().String(): // RFC 6763, 13.1 <_service>.<domain>
+		return r.sRecords(q)
+	case r.GetDiscoveryAddr().String():
+		return r.dRecords(q)
+	case r.GetHostnameAddr().String():
 		switch q.Qtype {
-		case dns.TypeA, dns.TypeAAAA:
-			return zr.iRecords(q)
+		case dns.TypeANY, dns.TypeA, dns.TypeAAAA:
+			return r.iRecords(q)
 		}
 		fallthrough
 	default:
@@ -87,146 +167,67 @@ func (zr *ZoneRecord) Records(q dns.Question) []dns.RR {
 	}
 }
 
-func (zr *ZoneRecord) sRecords(q dns.Question) []dns.RR {
+// iRecords: Instance records
+func (r *Resource) iRecords(q dns.Question) []dns.RR {
 	switch q.Qtype {
-	case dns.TypeANY, dns.TypePTR:
-		sr := []dns.RR{
-			&dns.PTR{
-				Hdr: dns.RR_Header{
-					Name:   q.Name,
-					Rrtype: dns.TypePTR,
-					Class:  dns.ClassINET,
-					Ttl:    config.DefaultTTL,
-				},
-				Ptr: zr.iAddr,
-			},
-		}
-
-		return append(sr, zr.Records(
-			dns.Question{
-				Name:  zr.iAddr,
-				Qtype: dns.TypeANY,
-			},
-		)...)
-
+	case dns.TypeANY:
+		return append(r.iRecords(dns.Question{
+			Name:  r.GetInstanceAddr().String(),
+			Qtype: dns.TypeSRV,
+		}), r.iRecords(dns.Question{
+			Name:  r.GetInstanceAddr().String(),
+			Qtype: dns.TypeTXT,
+		})...)
+	case dns.TypeA:
+		return r.axRecords(q, dns.TypeA)
+	case dns.TypeAAAA:
+		return r.axRecords(q, dns.TypeAAAA)
+	case dns.TypeSRV:
+		return append(
+			r.records(r.srvRecord(q)),
+			append(
+				r.iRecords(dns.Question{
+					Name:  r.GetInstanceAddr().String(),
+					Qtype: dns.TypeA,
+				}),
+				r.iRecords(dns.Question{
+					Name:  r.GetInstanceAddr().String(),
+					Qtype: dns.TypeAAAA,
+				})...,
+			)...,
+		)
+	case dns.TypeTXT:
+		return r.records(r.txtRecord(q))
 	default:
+		log.Traceln(errors.Zone, "None iRecord for", r.qType(q))
 		return nil
 	}
 }
 
-func (zr *ZoneRecord) iRecords(q dns.Question) []dns.RR {
+// sRecords: Service records
+func (r *Resource) sRecords(q dns.Question) []dns.RR {
 	switch q.Qtype {
-	case dns.TypeANY:
-		ir := zr.Records(
-			dns.Question{
-				Name:  zr.iAddr,
-				Qtype: dns.TypeSRV,
-			},
+	case dns.TypeANY, dns.TypePTR:
+		return append(
+			r.records(r.ptrRecord(q, r.GetInstanceAddr())),
+			r.iRecords(dns.Question{
+				Name:  r.GetInstanceAddr().String(),
+				Qtype: dns.TypeANY,
+			})...,
 		)
-
-		return append(ir, zr.Records(
-			dns.Question{
-				Name:  zr.iAddr,
-				Qtype: dns.TypeTXT,
-			},
-		)...)
-	case dns.TypeA:
-		var ir []dns.RR
-		for _, ip := range zr.GetIPs() {
-			switch ip.Type() {
-			case host.IPv4:
-				ir = append(ir, &dns.A{
-					Hdr: dns.RR_Header{
-						Name:   zr.GetHostname().String(),
-						Rrtype: dns.TypeA,
-						Class:  dns.ClassINET,
-						Ttl:    config.DefaultTTL,
-					},
-				})
-			case host.IPv6:
-				continue
-			}
-		}
-		return ir
-	case dns.TypeAAAA:
-		var ir []dns.RR
-		for _, ip := range zr.GetIPs() {
-			switch ip.Type() {
-			case host.IPv4:
-				continue
-			case host.IPv6:
-				ir = append(ir, &dns.AAAA{
-					Hdr: dns.RR_Header{
-						Name:   zr.GetHostname().String(),
-						Rrtype: dns.TypeAAAA,
-						Class:  dns.ClassINET,
-						Ttl:    config.DefaultTTL,
-					},
-					AAAA: ip.AsIP().To16(),
-				})
-			}
-		}
-		return ir
-	case dns.TypeSRV:
-		ir := []dns.RR{
-			&dns.SRV{
-				Hdr: dns.RR_Header{
-					Name:   q.Name,
-					Rrtype: dns.TypeSRV,
-					Class:  dns.ClassINET,
-					Ttl:    config.DefaultTTL,
-				},
-				Priority: 10,
-				Weight:   1,
-				Port:     zr.GetPort().Uint16(),
-				Target:   zr.GetHostname().String(),
-			},
-		}
-
-		ir = append(ir, zr.Records(
-			dns.Question{
-				Name:  zr.iAddr,
-				Qtype: dns.TypeA,
-			},
-		)...)
-
-		return append(ir, zr.Records(
-			dns.Question{
-				Name:  zr.iAddr,
-				Qtype: dns.TypeAAAA,
-			},
-		)...)
-	case dns.TypeTXT:
-		return []dns.RR{
-			&dns.TXT{
-				Hdr: dns.RR_Header{
-					Name:   q.Name,
-					Rrtype: dns.TypeTXT,
-					Class:  dns.ClassINET,
-					Ttl:    config.DefaultTTL,
-				},
-			},
-		}
 	default:
+		log.Traceln(errors.Zone, "None sRecord for", r.qType(q))
 		return nil
 	}
 }
 
-func (zr *ZoneRecord) sEnum(q dns.Question) []dns.RR {
+// dRecords: DNS-SD / Discovery aka enumerate
+func (r *Resource) dRecords(q dns.Question) []dns.RR {
 	switch q.Qtype {
 	case dns.TypeANY, dns.TypePTR:
-		return []dns.RR{
-			&dns.PTR{
-				Hdr: dns.RR_Header{
-					Name:   q.Name,
-					Rrtype: dns.TypePTR,
-					Class:  dns.ClassINET,
-					Ttl:    config.DefaultTTL,
-				},
-				Ptr: zr.sAddr,
-			},
-		}
+		return r.records(r.ptrRecord(q, r.GetServiceAddr()))
 	default:
+		log.Traceln(errors.Zone, "None dRecord for", r.qType(q))
 		return nil
 	}
 }

+ 37 - 3
zone/zone_test.go

@@ -1,16 +1,50 @@
 package zone
 
 import (
+	"net"
+	"os"
 	"testing"
+
+	"git.giaever.org/joachimmg/m-dns/host"
 )
 
-func TestZone(t *testing.T) {
-	txt := []string{"Info text"}
-	zone, err := New("instance", "_foo._bar", "", "", 8000, nil, txt)
+func TestZoneInit(t *testing.T) {
+
+	hostname, err := os.Hostname()
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	txt := []string{
+		"=ignore",
+		"key=value pair",
+		"key=",
+		"key   =   value=pair",
+		"k  e  y = v l a e u",
+		"--key",
+	}
+
+	host, err := host.New(
+		"This is info about \\my own service.",
+		"_http._tcp",
+		"local",
+		hostname,
+		[]net.IP{},
+		8001,
+		txt,
+	)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	zone, err := New(host)
 
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	t.Log(zone)
+	t.Fail()
 }