// SPDX-License-Identifier: BSD-3-Clause
package main
import (
"fmt"
"os"
)
// main builds the CLI application and runs it against the process arguments.
// Any error returned by the run is printed to standard error and the process
// exits with status 1.
func main() {
	if err := prepareCLI().Run(os.Args); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
// SPDX-License-Identifier: BSD-3-Clause
package main
import (
"fmt"
"math/rand"
"runtime"
"strings"
"git.froth.zone/sam/awl/conf"
"git.froth.zone/sam/awl/util"
"github.com/miekg/dns"
"github.com/urfave/cli/v2"
"golang.org/x/net/idna"
)
// prepareCLI builds the awl command-line application.
//
// As a side effect it overrides several package-level hooks of the cli
// library (VersionPrinter, VersionFlag, HelpFlag, FlagStringer and
// AppHelpTemplate) so that help and version output follow awl's own format.
//
// The returned app dispatches every invocation to doQuery.
func prepareCLI() *cli.App {
	// Custom version string that also reports the Go toolchain version.
	cli.VersionPrinter = func(c *cli.Context) {
		fmt.Printf("%s version %s, built with %s\n", c.App.Name, c.App.Version, runtime.Version())
	}

	cli.VersionFlag = &cli.BoolFlag{
		Name:  "v",
		Usage: "show version and exit",
	}

	cli.HelpFlag = &cli.BoolFlag{
		Name:  "h",
		Usage: "show this help and exit",
	}

	// Strip the " (default: false)" suffix the library appends to every
	// boolean flag in the help output.
	oldFlagStringer := cli.FlagStringer
	cli.FlagStringer = func(f cli.Flag) string {
		return strings.TrimSuffix(oldFlagStringer(f), " (default: false)")
	}

	cli.AppHelpTemplate = `{{.Name}} - {{.Usage}}
Usage: {{.HelpName}} name [@server] [record]
<name> can be a name or an IP address
<record> defaults to A
arguments can be in any order
{{if .VisibleFlags}}
Options:
{{range .VisibleFlags}}{{.}}
{{end}}{{end}}`

	app := &cli.App{
		Name:    "awl",
		Usage:   "drill, writ small",
		Version: "v0.2.1",
		Flags: []cli.Flag{
			&cli.IntFlag{
				Name:        "port",
				Aliases:     []string{"p"},
				Usage:       "`<port>` to make DNS query",
				DefaultText: "53 over plain TCP/UDP, 853 over TLS or QUIC",
			},
			&cli.BoolFlag{
				Name:  "4",
				Usage: "force IPv4",
			},
			&cli.BoolFlag{
				Name:  "6",
				Usage: "force IPv6",
			},
			&cli.BoolFlag{
				Name:    "dnssec",
				Aliases: []string{"D"},
				Usage:   "enable DNSSEC",
			},
			&cli.BoolFlag{
				Name:    "json",
				Aliases: []string{"j"},
				Usage:   "return the result(s) as JSON",
			},
			&cli.BoolFlag{
				Name:    "short",
				Aliases: []string{"s"},
				Usage:   "print just the results, equivalent to dig +short",
			},
			&cli.BoolFlag{
				Name:    "tcp",
				Aliases: []string{"t"},
				Usage:   "use TCP (default: use UDP)",
			},
			&cli.BoolFlag{
				Name:    "tls",
				Aliases: []string{"T"},
				Usage:   "use DNS-over-TLS",
			},
			&cli.BoolFlag{
				Name:    "https",
				Aliases: []string{"H"},
				Usage:   "use DNS-over-HTTPS",
			},
			&cli.BoolFlag{
				Name:    "quic",
				Aliases: []string{"Q"},
				Usage:   "use DNS-over-QUIC",
			},
			&cli.BoolFlag{
				Name:  "no-truncate",
				Usage: "ignore truncation if a UDP request truncates (default: retry with TCP)",
			},
			&cli.BoolFlag{
				Name:  "aa",
				Usage: "set AA (Authoritative Answer) flag (default: not set)",
			},
			&cli.BoolFlag{
				Name:  "tc",
				Usage: "set TC (TrunCated) flag (default: not set)",
			},
			&cli.BoolFlag{
				Name:  "z",
				Usage: "set Z (Zero) flag (default: not set)",
			},
			&cli.BoolFlag{
				Name:  "cd",
				Usage: "set CD (Checking Disabled) flag (default: not set)",
			},
			&cli.BoolFlag{
				Name:  "no-rd",
				Usage: "UNset RD (Recursion Desired) flag (default: set)",
			},
			&cli.BoolFlag{
				Name:  "no-ra",
				Usage: "UNset RA (Recursion Available) flag (default: set)",
			},
			&cli.BoolFlag{
				Name:    "reverse",
				Aliases: []string{"x"},
				Usage:   "do a reverse lookup",
			},
			// doQuery reads this flag to raise the log level; it has to be
			// registered here or the lookup always yields false.
			&cli.BoolFlag{
				Name:  "debug",
				Usage: "enable verbose logging",
			},
		},
		Action: doQuery,
	}

	return app
}
// parseArgs classifies the free-form positional arguments, drill style.
//
// Arguments may appear in any order and are interpreted as follows:
//   - "@server" selects the DNS server (everything after the leading '@');
//   - anything containing a '.' is the name to query;
//   - a known record type mnemonic (case-insensitive) selects the query type;
//   - anything else is treated as the name to query.
//
// Names are passed through idna.ToUnicode; a conversion failure is returned
// together with a zero util.Answers.
//
// Defaults: with no name the query is for "." (type NS unless a type was
// given); with a name the type defaults to A. With no server, one is picked
// at random from the system resolver configuration, falling back to 8.8.4.4
// when that configuration cannot be read or lists no servers.
func parseArgs(args []string) (util.Answers, error) {
	var (
		answers util.Answers
		err     error
	)

	for _, arg := range args {
		rrType, isType := dns.StringToType[strings.ToUpper(arg)]

		switch {
		case strings.HasPrefix(arg, "@"):
			// Only the leading '@' is a marker; keep the rest verbatim so a
			// server string that itself contains '@' is not truncated.
			answers.Server = strings.TrimPrefix(arg, "@")
		case strings.Contains(arg, "."):
			// Checked before the record-type case, so dotted arguments are
			// always names.
			answers.Name, err = idna.ToUnicode(arg)
			if err != nil {
				return util.Answers{}, err
			}
		case isType:
			answers.Request = rrType
		default:
			// Not a server, not dotted, not a record type: assume a name.
			answers.Name, err = idna.ToUnicode(arg)
			if err != nil {
				return util.Answers{}, err
			}
		}
	}

	// Fill in whatever the user left out.
	if answers.Name == "" {
		answers.Name = "."
		if answers.Request == 0 {
			answers.Request = dns.TypeNS
		}
	} else if answers.Request == 0 {
		answers.Request = dns.TypeA
	}

	if answers.Server == "" {
		resolv, err := conf.GetDNSConfig()
		// rand.Intn panics for n <= 0, so an empty server list must take the
		// fallback path as well.
		if err != nil || resolv == nil || len(resolv.Servers) == 0 {
			answers.Server = "8.8.4.4" // Query Google by default
		} else {
			answers.Server = resolv.Servers[rand.Intn(len(resolv.Servers))]
		}
	}

	return answers, nil
}
// SPDX-License-Identifier: BSD-3-Clause
//go:build !windows
// +build !windows
package conf
import (
"os"
"runtime"
"github.com/miekg/dns"
)
// GetDNSConfig returns the system DNS client configuration.
//
// On Plan 9 the contents of /net/ndb are read and handed to getPlan9Config;
// on every other platform covered by this file's build tag /etc/resolv.conf
// is parsed instead.
func GetDNSConfig() (*dns.ClientConfig, error) {
	if runtime.GOOS != "plan9" {
		return dns.ClientConfigFromFile("/etc/resolv.conf")
	}

	ndb, err := os.ReadFile("/net/ndb")
	if err != nil {
		return nil, err
	}

	return getPlan9Config(string(ndb))
}
// SPDX-License-Identifier: BSD-3-Clause
package conf
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// getPlan9Config extracts DNS servers from the text of Plan 9's /net/ndb.
//
// The input is split into whitespace-separated fields and every field of the
// form "dns=<value>" with a non-empty value contributes one server. The
// resulting configuration has an empty search list and port "53".
//
// An error is returned when no server is found.
//
// See ndb(7).
func getPlan9Config(str string) (*dns.ClientConfig, error) {
	// Line breaks separate fields just like blanks do. Turning them into
	// spaces (rather than deleting them) keeps the last field of one line
	// from fusing with the first field of the next.
	str = strings.ReplaceAll(str, "\n", " ")
	fields := strings.FieldsFunc(str, splitChars)

	var servers []string

	for _, field := range fields {
		if !strings.HasPrefix(field, "dns=") {
			continue
		}

		// A bare "dns=" carries no address; skip it instead of recording an
		// empty server.
		if server := strings.TrimPrefix(field, "dns="); server != "" {
			servers = append(servers, server)
		}
	}

	if len(servers) == 0 {
		return nil, fmt.Errorf("plan9: no DNS servers found")
	}

	// TODO: read more about how customizable Plan 9 is
	return &dns.ClientConfig{
		Servers: servers,
		Search:  []string{},
		Port:    "53",
	}, nil
}
// splitChars reports whether r is a field separator: a space or a tab.
// It is meant to be used as the predicate for strings.FieldsFunc.
func splitChars(r rune) bool {
	switch r {
	case ' ', '\t':
		return true
	default:
		return false
	}
}
// SPDX-License-Identifier: BSD-3-Clause
package logawl
import (
"fmt"
"io"
"sync"
"sync/atomic"
)
type (
	// Level is the severity of a log message. Lower values are more severe.
	Level int32

	// Logger writes levelled log lines to Out.
	Logger struct {
		Mu        sync.Mutex // exported lock available to code that writes through the logger
		Level     Level      // current verbosity; accessed atomically via SetLevel/GetLevel
		Prefix    string
		Out       io.Writer
		buf       []byte
		isDiscard int32
	}
)

// The log levels, from most to least severe.
const (
	// FatalLevel is for fatal logs (will call exit(1)).
	FatalLevel Level = iota
	// ErrorLevel is for error logs.
	ErrorLevel
	// InfoLevel is the "what is going on" level.
	InfoLevel
	// DebugLevel is the verbose log level.
	DebugLevel
)

// AllLevels lists every defined level in ascending numeric order.
var AllLevels = []Level{FatalLevel, ErrorLevel, InfoLevel, DebugLevel}

// SetLevel atomically stores level as the logger's verbosity.
func (l *Logger) SetLevel(level Level) {
	slot := (*int32)(&l.Level)
	atomic.StoreInt32(slot, int32(level))
}

// GetLevel returns the logger's current verbosity.
func (l *Logger) GetLevel() Level {
	return l.level()
}

// level atomically loads the logger's current verbosity.
func (l *Logger) level() Level {
	raw := atomic.LoadInt32((*int32)(&l.Level))
	return Level(raw)
}

// UnMarshalLevel returns the header label for lv, or an error when lv is not
// one of the defined levels.
func (l *Logger) UnMarshalLevel(lv Level) (string, error) {
	switch lv {
	case FatalLevel:
		return "FATAL ", nil
	case ErrorLevel:
		return "ERROR ", nil
	case InfoLevel:
		return "INFO ", nil
	case DebugLevel:
		return "DEBUG ", nil
	default:
		return "", fmt.Errorf("invalid log level choice")
	}
}

// IsLevel reports whether messages of the given level are enabled, i.e.
// whether level is no more verbose than the logger's current level.
func (l *Logger) IsLevel(level Level) bool {
	return level <= l.level()
}
// SPDX-License-Identifier: BSD-3-Clause
package logawl
import (
"fmt"
"os"
"sync/atomic"
"time"
)
// New returns a Logger that writes to standard error at InfoLevel.
//
// The level can afterwards be changed to any of the other log levels
// (FatalLevel, ErrorLevel, InfoLevel, DebugLevel).
func New() *Logger {
	logger := new(Logger)
	logger.Out = os.Stderr
	logger.Level = InfoLevel // default verbosity

	return logger
}
// Println formats v with fmt.Sprintln and writes it through Printer at the
// given level.
//
// Nothing is written when the logger is discarding output, when level is more
// verbose than the logger's current level, or when level is not one of the
// four known levels (0 fatal, 1 error, 2 info, 3 debug).
//
// A fatal message (level 0) terminates the process with exit status 1 after
// it has been written. Every other level, error included, returns to the
// caller so that the caller can propagate its own error.
//
// If the write itself fails, the failure is reported on standard error.
func (l *Logger) Println(level Level, v ...any) {
	if atomic.LoadInt32(&l.isDiscard) != 0 {
		return
	}

	// If verbose is not set (--debug etc.) print nothing for this level.
	if !l.IsLevel(level) {
		return
	}

	switch level {
	case 0, 1, 2, 3: // fatal, error, info, debug
	default:
		return
	}

	if err := l.Printer(level, fmt.Sprintln(v...)); err != nil {
		fmt.Fprintln(os.Stderr, "Logger failed:", err)
	}

	// Only fatal messages end the process.
	if level == 0 {
		os.Exit(1)
	}
}
// formatHeader appends the log line header to *buf in the form
//
//	<LEVEL> [YYYY/MM/DD HH:MM:SS]: 
//
// using the date and clock fields of t as given. The line parameter is
// accepted but not used. An error is returned, and *buf is left untouched,
// when level has no label.
func (l *Logger) formatHeader(buf *[]byte, t time.Time, line int, level Level) error {
	label, err := l.UnMarshalLevel(level)
	if err != nil {
		return fmt.Errorf("invalid log level choice")
	}

	year, month, day := t.Date()
	hour, minute, second := t.Clock()

	// Each timestamp field is preceded by its separator and zero-padded to a
	// fixed width.
	fields := [...]struct {
		sep   byte
		value int
		width int
	}{
		{'[', year, 4},
		{'/', int(month), 2},
		{'/', day, 2},
		{' ', hour, 2},
		{':', minute, 2},
		{':', second, 2},
	}

	*buf = append(*buf, label...)
	for _, f := range fields {
		*buf = append(*buf, f.sep)
		formatter(buf, f.value, f.width)
	}
	*buf = append(*buf, "]: "...)

	return nil
}
// Printer writes one log line to l.Out: the header for level followed by s,
// with a trailing newline added when s does not already end in one.
//
// The logger's mutex is held while its internal buffer is rebuilt and
// written.
//
// The message is written even when the header cannot be built (unknown
// level), in which case the line has no header. The returned error is the
// write error if the write failed, otherwise the header error, otherwise nil.
func (l *Logger) Printer(level Level, s string) error {
	now := time.Now()

	l.Mu.Lock()
	defer l.Mu.Unlock()

	l.buf = l.buf[:0]

	// Keep going on a header failure so the message is not lost, but remember
	// the error so the caller learns about it.
	headerErr := l.formatHeader(&l.buf, now, 0, level)

	l.buf = append(l.buf, s...)
	if len(s) == 0 || s[len(s)-1] != '\n' {
		l.buf = append(l.buf, '\n')
	}

	if _, err := l.Out.Write(l.buf); err != nil {
		return err
	}

	return headerErr
}
// formatter appends the decimal representation of i to *buf, left-padded
// with zeros to at least wid characters. A width of 1 or less (including a
// negative width) disables padding.
//
// Adapted from the itoa helper in the Go standard library's log package:
// https://cs.opensource.google/go/go/+/refs/tags/go1.18.3:src/log/log.go;drc=41e1d9075e428c2fc32d966b3752a3029b620e2c;l=96
func formatter(buf *[]byte, i int, wid int) {
	// Digits are produced least-significant first, so fill the scratch array
	// from its end towards its start.
	var digits [20]byte

	pos := len(digits) - 1
	for ; i >= 10 || wid > 1; wid-- {
		digits[pos] = byte('0' + i%10)
		pos--
		i /= 10
	}

	// Here i < 10: emit the leading digit.
	digits[pos] = byte('0' + i)

	*buf = append(*buf, digits[pos:]...)
}
// Debug logs v at DebugLevel; it is shorthand for Println(DebugLevel, v...).
func (l *Logger) Debug(v ...any) {
l.Println(DebugLevel, v...)
}
// Info logs v at InfoLevel; it is shorthand for Println(InfoLevel, v...).
func (l *Logger) Info(v ...any) {
l.Println(InfoLevel, v...)
}
// Error logs v at ErrorLevel; it is shorthand for Println(ErrorLevel, v...).
func (l *Logger) Error(v ...any) {
l.Println(ErrorLevel, v...)
}
// Fatal logs v at FatalLevel; it is shorthand for Println(FatalLevel, v...).
func (l *Logger) Fatal(v ...any) {
l.Println(FatalLevel, v...)
}
// SPDX-License-Identifier: BSD-3-Clause
package main
import (
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"time"
"git.froth.zone/sam/awl/logawl"
"git.froth.zone/sam/awl/query"
"git.froth.zone/sam/awl/util"
"github.com/miekg/dns"
"github.com/urfave/cli/v2"
)
// doQuery is the CLI action: it turns the parsed command line into a DNS
// query, sends it over the selected transport and prints the response.
//
// Steps, in order:
//  1. Classify the positional arguments (name, @server, record type).
//  2. Pick the port: the --port value, else 853 for TLS/QUIC, else 53.
//     For DNS-over-HTTPS the server is used as a URL and no port is appended.
//  3. With --reverse, rewrite the name into its reverse-lookup form.
//  4. Build the message and apply the header-bit and DNSSEC flags.
//  5. Send it over HTTPS, QUIC, or plain UDP/TCP/TLS. A truncated UDP answer
//     is retried over TCP unless --no-truncate is given.
//  6. Print the answer as JSON, in short form, or in full.
//
// Any error from argument parsing, reverse-name construction, the exchange
// or JSON encoding is returned to the caller.
func doQuery(c *cli.Context) error {
	var (
		err     error
		resp    util.Response
		isHTTPS bool
		logger  = logawl.New()
	)

	resp.Answers, err = parseArgs(c.Args().Slice())
	if err != nil {
		logger.Error("Unable to parse args")
		return err
	}

	port := c.Int("port")

	if c.Bool("debug") {
		logger.SetLevel(logawl.DebugLevel)
	}

	logger.Debug("Starting awl")

	// If port is not set, pick the default for the transport.
	if port == 0 {
		if c.Bool("tls") || c.Bool("quic") {
			port = 853
		} else {
			port = 53
		}
	}

	if c.Bool("https") || strings.HasPrefix(resp.Answers.Server, "https://") {
		// Add https:// if it doesn't already exist.
		if !strings.HasPrefix(resp.Answers.Server, "https://") {
			resp.Answers.Server = "https://" + resp.Answers.Server
		}

		isHTTPS = true
	} else {
		resp.Answers.Server = net.JoinHostPort(resp.Answers.Server, strconv.Itoa(port))
	}

	// Process the IP/phone number so a PTR/NAPTR lookup can be done.
	if c.Bool("reverse") {
		// A is only the default type here, so treat it as "not specified".
		if dns.TypeToString[resp.Answers.Request] == "A" {
			resp.Answers.Request = dns.StringToType["PTR"]
		}

		resp.Answers.Name, err = util.ReverseDNS(resp.Answers.Name, dns.TypeToString[resp.Answers.Request])
		if err != nil {
			return err
		}
	}

	// If the domain is not canonical, make it canonical.
	if !strings.HasSuffix(resp.Answers.Name, ".") {
		resp.Answers.Name += "."
	}

	msg := new(dns.Msg)
	msg.SetQuestion(resp.Answers.Name, resp.Answers.Request)

	// Make this authoritative (does this do anything?)
	if c.Bool("aa") {
		msg.Authoritative = true
	}
	// Set truncated flag (why)
	if c.Bool("tc") {
		msg.Truncated = true
	}
	// Set the zero flag if requested (does nothing)
	if c.Bool("z") {
		logger.Debug("Setting message to zero")
		msg.Zero = true
	}
	// Disable DNSSEC validation
	if c.Bool("cd") {
		msg.CheckingDisabled = true
	}
	// Disable wanting recursion
	if c.Bool("no-rd") {
		msg.RecursionDesired = false
	}
	// Disable recursion being available (I don't think this does anything)
	if c.Bool("no-ra") {
		msg.RecursionAvailable = false
	}
	// Set DNSSEC if requested
	if c.Bool("dnssec") {
		logger.Debug("Using DNSSEC")
		msg.SetEdns0(1232, true)
	}

	var in *dns.Msg

	// Make the DNS request.
	switch {
	case isHTTPS:
		in, resp.Answers.RTT, err = query.ResolveHTTPS(msg, resp.Answers.Server)
	case c.Bool("quic"):
		in, resp.Answers.RTT, err = query.ResolveQUIC(msg, resp.Answers.Server)
	default:
		d := new(dns.Client)

		// IPv4 or IPv6 suffix for the network name, depending on flags.
		family := ""

		switch {
		case c.Bool("4"):
			family = "4"
		case c.Bool("6"):
			family = "6"
		}

		// TCP is used when asked for explicitly and as the carrier for TLS.
		overUDP := !c.Bool("tcp") && !c.Bool("tls")
		if overUDP {
			d.Net = "udp" + family
		} else {
			d.Net = "tcp" + family
		}

		// Add TLS, if requested.
		if c.Bool("tls") {
			d.Net += "-tls"
		}

		in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
		if err != nil {
			return err
		}

		// If UDP truncates, use TCP instead (unless truncation is to be
		// ignored). The retry is limited to queries that really went out
		// over UDP: retrying a TCP query gains nothing, and retrying a TLS
		// query over plain TCP would re-send it unencrypted.
		if overUDP && in.MsgHdr.Truncated && !c.Bool("no-truncate") {
			fmt.Printf(";; Truncated, retrying with TCP\n\n")

			d.Net = "tcp" + family
			in, resp.Answers.RTT, err = d.Exchange(msg, resp.Answers.Server)
		}
	}

	if err != nil {
		return err
	}

	switch {
	case c.Bool("json"):
		encoded, jsonErr := json.MarshalIndent(in, "", " ")
		if jsonErr != nil {
			return jsonErr
		}

		fmt.Println(string(encoded))
	case c.Bool("short"):
		// Print just the responses, nothing else: the record data is the
		// last tab-separated column of each answer's text form.
		for _, res := range in.Answer {
			columns := strings.Split(res.String(), "\t")
			fmt.Println(columns[len(columns)-1])
		}
	default:
		// Print everything.
		fmt.Println(in)
		fmt.Println(";; Query time:", resp.Answers.RTT)
		fmt.Println(";; SERVER:", resp.Answers.Server)
		fmt.Println(";; WHEN:", time.Now().Format(time.RFC1123Z))
		fmt.Println(";; MSG SIZE rcvd:", in.Len())
	}

	return nil
}
// SPDX-License-Identifier: BSD-3-Clause
package query
import (
"bytes"
"fmt"
"io"
"net/http"
"time"
"github.com/miekg/dns"
)
// ResolveHTTPS sends msg to server as a DNS-over-HTTPS query and returns the
// decoded response together with the time the HTTP round trip took.
//
// server must be the full URL of the DoH endpoint. Only POST is supported:
// the packed message is sent as an application/dns-message body. (A GET
// variant would instead carry the base64url-encoded message in a "dns" query
// parameter.)
//
// Any status other than 200 OK is an error. The request is abandoned after
// ten seconds so an unresponsive server cannot hang the caller forever.
// Underlying errors are wrapped, so errors.Is and errors.As work on the
// result.
func ResolveHTTPS(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
	// Upper bound for the whole exchange, including reading the body.
	const dohTimeout = 10 * time.Second

	httpR := &http.Client{Timeout: dohTimeout}

	buf, err := msg.Pack()
	if err != nil {
		return nil, 0, err
	}

	req, err := http.NewRequest(http.MethodPost, server, bytes.NewReader(buf))
	if err != nil {
		return nil, 0, fmt.Errorf("DoH: %w", err)
	}

	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	now := time.Now()
	res, err := httpR.Do(req)
	rtt := time.Since(now)

	if err != nil {
		return nil, 0, fmt.Errorf("DoH HTTP request error: %w", err)
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return nil, 0, fmt.Errorf("DoH server responded with HTTP %d", res.StatusCode)
	}

	fullRes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("DoH body read error: %w", err)
	}

	response := new(dns.Msg)
	if err = response.Unpack(fullRes); err != nil {
		return nil, 0, fmt.Errorf("DoH dns message unpack error: %w", err)
	}

	return response, rtt, nil
}
// SPDX-License-Identifier: BSD-3-Clause
package query
import (
"crypto/tls"
"io"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
)
// ResolveQUIC sends msg to server as a DNS-over-QUIC query and returns the
// decoded response together with the round-trip time.
//
// server is a host:port address. The connection is negotiated with the "doq"
// ALPN token. The round-trip time covers opening the stream, writing the
// query and reading the complete answer; it does not include the handshake.
//
// The connection is closed on every path out of the function, including the
// failure paths.
func ResolveQUIC(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
	// Pack first: if the message is invalid there is no point in dialling,
	// and no connection to clean up.
	buf, err := msg.Pack()
	if err != nil {
		return nil, 0, err
	}

	tlsConf := &tls.Config{
		NextProtos: []string{"doq"},
	}

	connection, err := quic.DialAddr(server, tlsConf, nil)
	if err != nil {
		return nil, 0, err
	}

	// abort releases the connection before reporting cause. The result of
	// the close is ignored because cause is the error that matters.
	abort := func(cause error) (*dns.Msg, time.Duration, error) {
		_ = connection.CloseWithError(0, "")
		return nil, 0, cause
	}

	t := time.Now()

	stream, err := connection.OpenStream()
	if err != nil {
		return abort(err)
	}

	if _, err = stream.Write(buf); err != nil {
		return abort(err)
	}

	fullRes, err := io.ReadAll(stream)
	if err != nil {
		return abort(err)
	}

	rtt := time.Since(t)

	// Close with error: no error.
	if err = connection.CloseWithError(0, ""); err != nil {
		return nil, 0, err
	}

	if err = stream.Close(); err != nil {
		return nil, 0, err
	}

	response := new(dns.Msg)
	if err = response.Unpack(fullRes); err != nil {
		return nil, 0, err
	}

	return response, rtt, nil
}
// SPDX-License-Identifier: BSD-3-Clause
package util
import (
"errors"
"fmt"
"strings"
"time"
"github.com/miekg/dns"
)
type (
	// Response is the top-level result wrapper; its Answers field is
	// serialised under the JSON key "Response".
	Response struct {
		Answers Answers `json:"Response"` // These be DNS query answers
	}

	// Answers is the basic structure of a DNS request
	// to be returned to the user upon making a request.
	Answers struct {
		Server  string        `json:"Server"`  // The server to make the DNS request from
		Request uint16        `json:"Request"` // The type of request
		Name    string        `json:"Name"`    // The domain name to make a DNS request for
		RTT     time.Duration `json:"RTT"`     // The time it took to make the DNS query
	}
)
// ReverseDNS converts address into the canonical name to query for a reverse
// lookup of the given record type.
//
// For "PTR" the address is handed to dns.ReverseAddr. For "NAPTR" the address
// is treated as a phone number: '+', ' ' and '-' are removed, the remaining
// characters are reversed, each is followed by a dot, and "e164.arpa." is
// appended. Any other query type is an error.
func ReverseDNS(address string, query string) (string, error) {
	switch query {
	case "PTR":
		return dns.ReverseAddr(address)
	case "NAPTR":
		// Drop the punctuation commonly found in written phone numbers, then
		// flip the order of what is left.
		stripped := strings.NewReplacer("+", "", " ", "", "-", "").Replace(address)
		flipped := reverse(stripped)

		// Make it canonical: one label per character, then the E.164 zone.
		var name strings.Builder
		for _, ch := range flipped {
			fmt.Fprintf(&name, "%c.", ch)
		}

		name.WriteString("e164.arpa.")

		return name.String(), nil
	default:
		return "", errors.New("ReverseDNS: -x flag given but no IP found")
	}
}
// reverse returns s with its runes in the opposite order.
func reverse(s string) string {
	forward := []rune(s)
	backward := make([]rune, len(forward))

	last := len(forward) - 1
	for i, r := range forward {
		backward[last-i] = r
	}

	return string(backward)
}