Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions cacher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package godns

import (
"encoding/binary"
"fmt"
"io"
"log"
"os"
"path/filepath"

"github.com/panjf2000/ants/v2"
"github.com/tochusc/godns/dns"
)

type Cacher struct {
CacheLocation string
CacherLogger *log.Logger
CacherPool *ants.Pool
}

type CacherConfig struct {
CacheLocation string
LogWriter io.Writer
}

func NewCacher(conf CacherConfig, pool *ants.Pool) *Cacher {
cacherLogger := log.New(conf.LogWriter, "Cacher: ", log.LstdFlags)

return &Cacher{
CacheLocation: conf.CacheLocation,
CacherLogger: cacherLogger,
CacherPool: pool,
}
}

func (c *Cacher) CacheResponse(data []byte) error {
ident, err := IdentifyMessage(data)
if err != nil {
c.CacherLogger.Printf("Error identifying response: %v", err)
return err
}

path := filepath.Join(c.CacheLocation, ident)

// 将响应缓存到磁盘

// 如果缓存目录不存在,创建目录
if _, err := os.Stat(c.CacheLocation); os.IsNotExist(err) {
err := os.MkdirAll(c.CacheLocation, 0755)
if err != nil {
c.CacherLogger.Printf("Error creating cache directory %s: %v", c.CacheLocation, err)
return err
}
}

// 创建缓存文件
file, err := os.Create(path)
if err != nil {
c.CacherLogger.Printf("Error creating cache file %s: %v", ident, err)
return err
}

_, err = file.Write(data)
if err != nil {
c.CacherLogger.Printf("Error writing cache file %s: %v", ident, err)
return err
}

file.Close()
c.CacherLogger.Printf("Cache saved %s\n", ident)

return nil
}

func (c *Cacher) FetchCache(connInfo ConnectionInfo) ([]byte, error) {

ident, err := IdentifyMessage(connInfo.Packet)
if err != nil {
c.CacherLogger.Printf("Error identifying response: %v", err)
return []byte{}, err
}

path := filepath.Join(c.CacheLocation, ident)

file, err := os.Open(path)
if err != nil {
c.CacherLogger.Printf("Cache miss %s\n", ident)
return []byte{}, err
}
defer file.Close()

cache := make([]byte, 65535)
rd, err := file.Read(cache)
if err != nil {
c.CacherLogger.Printf("Error reading cache file %s: %v", ident, err)
return []byte{}, err
}

c.CacherLogger.Printf("Cache hit %s\n", ident)

// 修改Cache内容
cache[0] = connInfo.Packet[0]
cache[1] = connInfo.Packet[1]
for i := 0; ; i++ {
cache[12+i] = connInfo.Packet[12+i]

if cache[12+i] > dns.NamePointerFlag {
cache[13+i] = connInfo.Packet[13+i]
break
}
if cache[12+i] == 0x00 {
break
}
}

return cache[:rd], nil
}

func IdentifyMessage(data []byte) (string, error) {
// 解析 DNS 请求
qName, offset, err := dns.DecodeDomainNameFromBuffer(data, 12)
if err != nil {
return "", err
}
qType := dns.DNSType(binary.BigEndian.Uint16(data[offset : offset+2]))
qClass := dns.DNSClass(binary.BigEndian.Uint16(data[offset+2 : offset+4]))
return fmt.Sprintf("%s-%s-%s", qName, qType.String(), qClass.String()), nil
}
10 changes: 5 additions & 5 deletions dns/xperi/dnssec.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func GenerateRandomRRRRSIG(rrSet []dns.DNSResourceRecord, algo dns.DNSSECAlgorit
return rr
}

func GenerateRandomRDATADS(oName string, kRDATA dns.DNSRDATADNSKEY, dType dns.DNSSECDigestType) dns.DNSRDATADS {
func GenerateRandomRDATADS(oName string, keytag int, algo dns.DNSSECAlgorithm, dType dns.DNSSECDigestType) dns.DNSRDATADS {
rText := []byte(GenerateRandomString(96))
var digest []byte
switch dType {
Expand All @@ -442,15 +442,15 @@ func GenerateRandomRDATADS(oName string, kRDATA dns.DNSRDATADNSKEY, dType dns.DN

// 4. 构建 DS RDATA
return dns.DNSRDATADS{
KeyTag: CalculateKeyTag(kRDATA),
Algorithm: kRDATA.Algorithm,
KeyTag: uint16(keytag),
Algorithm: algo,
DigestType: dType,
Digest: digest[:],
}
}

func GenerateRandomRRDS(oName string, kRDATA dns.DNSRDATADNSKEY, dType dns.DNSSECDigestType) dns.DNSResourceRecord {
rdata := GenerateRandomRDATADS(oName, kRDATA, dType)
func GenerateRandomRRDS(oName string, keytag int, algo dns.DNSSECAlgorithm, dType dns.DNSSECDigestType) dns.DNSResourceRecord {
rdata := GenerateRandomRDATADS(oName, keytag, algo, dType)
rr := dns.DNSResourceRecord{
Name: oName,
Type: dns.DNSRRTypeDS,
Expand Down
Loading