Browse Source

Init commit. Added zone (records) and host-specific.

Joachim M. Giæver 8 years ago
parent
commit
9b58b3f5cd
6 changed files with 597 additions and 0 deletions
  1. 26 0
      config/config.go
  2. 17 0
      errors/errors.go
  3. 253 0
      host/host.go
  4. 53 0
      host/host_test.go
  5. 232 0
      zone/zone.go
  6. 16 0
      zone/zone_test.go

+ 26 - 0
config/config.go

@@ -0,0 +1,26 @@
+package config
+
+import (
+	"net"
+)
+
+const (
+	MdnsIPv4 = "224.0.0.251"
+	MdnsIPv6 = "ff02::fb"
+	MdnsPort = 5353
+
+	ForceUnicast = false
+
+	DefaultTTL = 120
+)
+
+var (
+	MdnsIPv4Addr = &net.UDPAddr{
+		IP:   net.ParseIP(MdnsIPv4),
+		Port: MdnsPort,
+	}
+	MdnsIPv6Addr = &net.UDPAddr{
+		IP:   net.ParseIP(MdnsIPv6),
+		Port: MdnsPort,
+	}
+)

+ 17 - 0
errors/errors.go

@@ -0,0 +1,17 @@
+package errors
+
+import (
+	e "errors"
+)
+
+var (
+	ZoneInstanseIsEmpty = e.New("Zone: Instance name is missing.")
+
+	HostServiceIsEmpty            = e.New("Host: Service name is missing.")
+	HostHostnameNotFQDM           = e.New("Host: Hostname is not a fully-qualified domain name.")
+	HostHostnameCouldNotDetermine = e.New("Host: Could not determine hostname")
+	HostIPCouldNotDetermine       = e.New("Host: Could not determine IP address(es) for hostname")
+	HostIPAddressIsInvalid        = e.New("Host: Invalid IP-address in IP-list")
+	HostDomainNotFQDM             = e.New("Host: Domain is not a fully-qualified domain name.")
+	HostPortInvalid               = e.New("Host: Port is missing or invalid.")
+)

+ 253 - 0
host/host.go

@@ -0,0 +1,253 @@
+package host
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"strconv"
+	"strings"
+
+	"git.giaever.org/joachimmg/go-log.git/log"
+	"git.giaever.org/joachimmg/m-dns/errors"
+)
+
+type HostString string
+type HostIP net.IP
+type HostIPs []HostIP
+type HostIPvType int
+type HostPort int
+
+const (
+	IPv4 HostIPvType = iota << 1
+	IPv6 HostIPvType = iota << 1
+	NoIP HostIPvType = iota << 1
+)
+
+type Host struct {
+	service  HostString
+	domain   HostString
+	hostname HostString
+	ips      HostIPs
+	port     HostPort
+}
+
+func New() *Host {
+	return new(Host)
+}
+
+func (h *Host) SetService(s string) error {
+	if err := HostString(s).validService(); err != nil {
+		return err
+	}
+	h.service = HostString(s)
+	return nil
+}
+
+func (h *Host) SetDomain(d string) error {
+	if len(d) == 0 {
+		d = "local."
+	}
+
+	if err := HostString(d).validDomain(); err != nil {
+		return err
+	}
+
+	h.domain = HostString(d)
+	return nil
+}
+
+func (h *Host) SetHostname(hn string) error {
+	if len(hn) == 0 {
+		var err error
+		hn, err = os.Hostname()
+		if err != nil {
+			log.Warningln(h, hn, err)
+			return errors.HostHostnameCouldNotDetermine
+		}
+		hn = fmt.Sprintf("%s.", hn)
+	}
+
+	if err := HostString(hn).validHostname(); err != nil {
+		return err
+	}
+
+	h.hostname = HostString(hn)
+	return nil
+}
+
+func (h *Host) SetIPs(i []net.IP) error {
+	if len(i) == 0 {
+		var err error
+		i, err = net.LookupIP(fmt.Sprintf("%s%s", h.GetHostname(), h.GetDomain()))
+
+		if err != nil {
+			log.Warningln(h, err)
+			return errors.HostIPCouldNotDetermine
+		}
+
+	}
+
+	var ips HostIPs
+
+	for _, ip := range i {
+		ips = append(ips, HostIP(ip))
+	}
+
+	if err := ips.validIPs(); err != nil {
+		log.Warningln(err)
+		return err
+	}
+
+	h.ips = ips
+	return nil
+}
+
+func (h *Host) SetPort(p int) error {
+	if err := HostPort(p).validPort(); err != nil {
+		return err
+	}
+	h.port = HostPort(p)
+	return nil
+}
+
+func (h *Host) GetService() HostString {
+	if err := h.service.validService(); err != nil {
+		log.Panicln(err)
+	}
+	return h.service
+}
+
+func (h *Host) GetDomain() HostString {
+	if err := h.domain.validDomain(); err != nil {
+		log.Panicln(err)
+	}
+	return h.domain
+}
+
+func (h *Host) GetHostname() HostString {
+	if err := h.hostname.validHostname(); err != nil {
+		log.Panicln(err)
+	}
+	return h.hostname
+}
+
+func (h *Host) GetIPs() HostIPs {
+	if err := h.ips.validIPs(); err != nil {
+		log.Panicln(err)
+	}
+	return h.ips
+}
+
+func (h *Host) GetPort() HostPort {
+	if err := h.port.validPort(); err != nil {
+		log.Panicln(err)
+	}
+	return h.port
+}
+
+func (h *Host) String() string {
+	return fmt.Sprintf("Service: %s\nDomain: %s\nHostname: %s\nIPs: %s\nPort: %s\n",
+		h.GetService(),
+		h.GetDomain(),
+		h.GetHostname(),
+		strings.Join(h.GetIPs().String(), ", "),
+		h.GetPort(),
+	)
+}
+
+func (h HostString) ValidFQDM() bool {
+	if len(h) == 0 {
+		log.Warningln("Cannot be blank", h)
+		return false
+	}
+
+	if h[len(h)-1] != '.' {
+		log.Warningf("Missing trailing '.' in <%s>", h)
+		return false
+	}
+
+	return true
+}
+
+func (h HostString) validService() error {
+	if len(h) == 0 {
+		log.Warningln("Service is empty", h)
+		return errors.HostServiceIsEmpty
+	}
+	return nil
+}
+
+func (h HostString) validHostname() error {
+	if !h.ValidFQDM() {
+		return errors.HostHostnameNotFQDM
+	}
+	return nil
+}
+
+func (h HostString) validDomain() error {
+	if !h.ValidFQDM() {
+		return errors.HostDomainNotFQDM
+	}
+	return nil
+}
+
+func (h HostString) TrimDot() string {
+	return strings.Trim(h.String(), ".")
+}
+
+func (h HostString) String() string {
+	return string(h)
+}
+
+func (h HostIPs) validIPs() error {
+	for _, ip := range h {
+		if t := ip.Type(); t == NoIP {
+			return errors.HostIPAddressIsInvalid
+		}
+	}
+	return nil
+}
+
+func (h HostIPs) String() []string {
+	s := []string{}
+	for _, ip := range h {
+		s = append(s, ip.String())
+	}
+	return s
+}
+
+func (h HostIP) Type() HostIPvType {
+	if h.AsIP().To4() != nil {
+		return IPv4
+	}
+
+	if h.AsIP().To16() != nil {
+		return IPv6
+	}
+
+	return NoIP
+}
+
+func (h HostIP) AsIP() net.IP {
+	return net.IP(h)
+}
+
+func (h HostIP) String() string {
+	return h.AsIP().String()
+}
+
+func (h HostPort) validPort() error {
+	if h <= 0 {
+		log.Warningf("Port <%d> is invalid.", h)
+		return errors.HostPortInvalid
+	}
+	return nil
+}
+
+func (h HostPort) Uint16() uint16 {
+	return uint16(h)
+}
+
+func (h HostPort) String() string {
+	return strconv.Itoa(int(h))
+}

+ 53 - 0
host/host_test.go

@@ -0,0 +1,53 @@
+package host
+
+import (
+	"net"
+	"os"
+	"testing"
+)
+
+func TestHost(t *testing.T) {
+	host := New()
+
+	var ips []net.IP
+
+	hostname, _ := os.Hostname()
+
+	if err := host.SetService(""); err == nil {
+		t.Fatal("Not expecting empty service name to be valid")
+	}
+
+	if err := host.SetService("service"); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := host.SetHostname(hostname); err == nil {
+		t.Fatal("Hostname should contain trailing .")
+	}
+
+	if err := host.SetHostname(hostname + "."); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := host.SetDomain("local"); err == nil {
+		t.Fatal("Domain should contain trailing .")
+	}
+
+	if err := host.SetDomain("local."); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := host.SetIPs(ips); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := host.SetPort(0); err == nil {
+		t.Fatal("Port should not be 0 or less")
+	}
+
+	if err := host.SetPort(4545); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log(host.String())
+}

+ 232 - 0
zone/zone.go

@@ -0,0 +1,232 @@
+package zone
+
+import (
+	"fmt"
+	"net"
+
+	"git.giaever.org/joachimmg/m-dns/config"
+	"git.giaever.org/joachimmg/m-dns/errors"
+	"git.giaever.org/joachimmg/m-dns/host"
+	"github.com/miekg/dns"
+)
+
+type Zone interface {
+	Records(q dns.Question)
+}
+
+type ZoneRecord struct {
+	Instance host.HostString
+	TXT      []string
+
+	sAddr string
+	iAddr string
+	eAddr string
+
+	*host.Host
+}
+
+func New(instance, service, domain, hostname string, port int, ips []net.IP, txt []string) (*ZoneRecord, error) {
+	zr := new(ZoneRecord)
+	zr.Host = new(host.Host)
+
+	if len(instance) == 0 {
+		return nil, errors.ZoneInstanseIsEmpty
+	}
+
+	zr.Instance = host.HostString(instance)
+
+	if err := zr.SetService(service); err != nil {
+		return nil, err
+	}
+
+	if err := zr.SetDomain(domain); err != nil {
+		return nil, err
+	}
+
+	if err := zr.SetHostname(hostname); err != nil {
+		return nil, err
+	}
+
+	if err := zr.SetPort(port); err != nil {
+		return nil, err
+	}
+
+	zr.TXT = txt
+
+	zr.sAddr = fmt.Sprintf("%s.%s.",
+		zr.GetService().TrimDot(),
+		zr.GetDomain().TrimDot(),
+	)
+	zr.iAddr = fmt.Sprintf("%s.%s.%s.",
+		zr.Instance.TrimDot,
+		zr.GetService().TrimDot(),
+		zr.GetDomain().TrimDot(),
+	)
+	zr.eAddr = fmt.Sprintf("_services._dns-ds._udp.%s.",
+		zr.GetDomain().TrimDot(),
+	)
+	return zr, nil
+}
+
+func (zr *ZoneRecord) Records(q dns.Question) []dns.RR {
+	switch q.Name {
+	case zr.sAddr:
+		return zr.sRecords(q)
+	case zr.iAddr:
+		return zr.iRecords(q)
+	case zr.eAddr:
+		return zr.sEnum(q)
+	case zr.GetHostname().String():
+		switch q.Qtype {
+		case dns.TypeA, dns.TypeAAAA:
+			return zr.iRecords(q)
+		}
+		fallthrough
+	default:
+		return nil
+	}
+}
+
+func (zr *ZoneRecord) sRecords(q dns.Question) []dns.RR {
+	switch q.Qtype {
+	case dns.TypeANY, dns.TypePTR:
+		sr := []dns.RR{
+			&dns.PTR{
+				Hdr: dns.RR_Header{
+					Name:   q.Name,
+					Rrtype: dns.TypePTR,
+					Class:  dns.ClassINET,
+					Ttl:    config.DefaultTTL,
+				},
+				Ptr: zr.iAddr,
+			},
+		}
+
+		return append(sr, zr.Records(
+			dns.Question{
+				Name:  zr.iAddr,
+				Qtype: dns.TypeANY,
+			},
+		)...)
+
+	default:
+		return nil
+	}
+}
+
+func (zr *ZoneRecord) iRecords(q dns.Question) []dns.RR {
+	switch q.Qtype {
+	case dns.TypeANY:
+		ir := zr.Records(
+			dns.Question{
+				Name:  zr.iAddr,
+				Qtype: dns.TypeSRV,
+			},
+		)
+
+		return append(ir, zr.Records(
+			dns.Question{
+				Name:  zr.iAddr,
+				Qtype: dns.TypeTXT,
+			},
+		)...)
+	case dns.TypeA:
+		var ir []dns.RR
+		for _, ip := range zr.GetIPs() {
+			switch ip.Type() {
+			case host.IPv4:
+				ir = append(ir, &dns.A{
+					Hdr: dns.RR_Header{
+						Name:   zr.GetHostname().String(),
+						Rrtype: dns.TypeA,
+						Class:  dns.ClassINET,
+						Ttl:    config.DefaultTTL,
+					},
+				})
+			case host.IPv6:
+				continue
+			}
+		}
+		return ir
+	case dns.TypeAAAA:
+		var ir []dns.RR
+		for _, ip := range zr.GetIPs() {
+			switch ip.Type() {
+			case host.IPv4:
+				continue
+			case host.IPv6:
+				ir = append(ir, &dns.AAAA{
+					Hdr: dns.RR_Header{
+						Name:   zr.GetHostname().String(),
+						Rrtype: dns.TypeAAAA,
+						Class:  dns.ClassINET,
+						Ttl:    config.DefaultTTL,
+					},
+					AAAA: ip.AsIP().To16(),
+				})
+			}
+		}
+		return ir
+	case dns.TypeSRV:
+		ir := []dns.RR{
+			&dns.SRV{
+				Hdr: dns.RR_Header{
+					Name:   q.Name,
+					Rrtype: dns.TypeSRV,
+					Class:  dns.ClassINET,
+					Ttl:    config.DefaultTTL,
+				},
+				Priority: 10,
+				Weight:   1,
+				Port:     zr.GetPort().Uint16(),
+				Target:   zr.GetHostname().String(),
+			},
+		}
+
+		ir = append(ir, zr.Records(
+			dns.Question{
+				Name:  zr.iAddr,
+				Qtype: dns.TypeA,
+			},
+		)...)
+
+		return append(ir, zr.Records(
+			dns.Question{
+				Name:  zr.iAddr,
+				Qtype: dns.TypeAAAA,
+			},
+		)...)
+	case dns.TypeTXT:
+		return []dns.RR{
+			&dns.TXT{
+				Hdr: dns.RR_Header{
+					Name:   q.Name,
+					Rrtype: dns.TypeTXT,
+					Class:  dns.ClassINET,
+					Ttl:    config.DefaultTTL,
+				},
+			},
+		}
+	default:
+		return nil
+	}
+}
+
+func (zr *ZoneRecord) sEnum(q dns.Question) []dns.RR {
+	switch q.Qtype {
+	case dns.TypeANY, dns.TypePTR:
+		return []dns.RR{
+			&dns.PTR{
+				Hdr: dns.RR_Header{
+					Name:   q.Name,
+					Rrtype: dns.TypePTR,
+					Class:  dns.ClassINET,
+					Ttl:    config.DefaultTTL,
+				},
+				Ptr: zr.sAddr,
+			},
+		}
+	default:
+		return nil
+	}
+}

+ 16 - 0
zone/zone_test.go

@@ -0,0 +1,16 @@
+package zone
+
+import (
+	"testing"
+)
+
+func TestZone(t *testing.T) {
+	txt := []string{"Info text"}
+	zone, err := New("instance", "_foo._bar", "", "", 8000, nil, txt)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log(zone)
+}