From 5ca5020cfa78811c4784f405dd633791c46ac081 Mon Sep 17 00:00:00 2001
From: 6543 <6543@obermui.de>
Date: Fri, 3 Dec 2021 04:15:48 +0100
Subject: [PATCH] open key-database deterministic

---
 cmd/certs.go                 | 14 +++++---
 cmd/main.go                  | 14 ++++++--
 server/cache/interface.go    |  8 +++++
 server/cache/setup.go        |  7 ++++
 server/certificates.go       | 63 ++++++++++++++----------------------
 server/database/helpers.go   |  5 ++-
 server/database/interface.go | 12 +++++++
 server/database/setup.go     | 19 +++++++++++
 8 files changed, 94 insertions(+), 48 deletions(-)
 create mode 100644 server/cache/interface.go
 create mode 100644 server/cache/setup.go
 create mode 100644 server/database/interface.go
 create mode 100644 server/database/setup.go

diff --git a/cmd/certs.go b/cmd/certs.go
index 4676520..89521da 100644
--- a/cmd/certs.go
+++ b/cmd/certs.go
@@ -1,11 +1,12 @@
 package cmd
 
 import (
+	"fmt"
 	"os"
 
 	"github.com/urfave/cli/v2"
 
-	pages_server "codeberg.org/codeberg/pages/server"
+	"codeberg.org/codeberg/pages/server/database"
 )
 
 var Certs = &cli.Command{
@@ -23,15 +24,18 @@ func certs(ctx *cli.Context) error {
 
 		domains := ctx.Args().Slice()[2:]
 
-		if pages_server.KeyDatabaseErr != nil {
-			panic(pages_server.KeyDatabaseErr)
+		// TODO: make "key-database.pogreb" set via flag
+		keyDatabase, err := database.New("key-database.pogreb")
+		if err != nil {
+			return fmt.Errorf("could not create database: %v", err)
 		}
+
 		for _, domain := range domains {
-			if err := pages_server.KeyDatabase.Delete([]byte(domain)); err != nil {
+			if err := keyDatabase.Delete([]byte(domain)); err != nil {
 				panic(err)
 			}
 		}
-		if err := pages_server.KeyDatabase.Sync(); err != nil {
+		if err := keyDatabase.Sync(); err != nil {
 			panic(err)
 		}
 		os.Exit(0)
diff --git a/cmd/main.go b/cmd/main.go
index a7d606d..2f3d7ac 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -15,6 +15,8 @@ import (
 	"github.com/valyala/fasthttp"
 
 	"codeberg.org/codeberg/pages/server"
+	"codeberg.org/codeberg/pages/server/cache"
+	"codeberg.org/codeberg/pages/server/database"
 	"codeberg.org/codeberg/pages/server/utils"
 )
 
@@ -84,9 +86,17 @@ func Serve(ctx *cli.Context) error {
 	if err != nil {
 		return fmt.Errorf("couldn't create listener: %s", err)
 	}
-	listener = tls.NewListener(listener, server.TlsConfig(mainDomainSuffix, giteaRoot, giteaAPIToken, dnsProvider, acmeUseRateLimits))
 
-	server.SetupCertificates(mainDomainSuffix, acmeAPI, acmeMail, acmeEabHmac, acmeEabKID, dnsProvider, acmeUseRateLimits, acmeAcceptTerms, enableHTTPServer)
+	// TODO: make "key-database.pogreb" set via flag
+	keyDatabase, err := database.New("key-database.pogreb")
+	if err != nil {
+		return fmt.Errorf("could not create database: %v", err)
+	}
+
+	keyCache := cache.NewKeyValueCache()
+	listener = tls.NewListener(listener, server.TLSConfig(mainDomainSuffix, giteaRoot, giteaAPIToken, dnsProvider, acmeUseRateLimits, keyCache, keyDatabase))
+
+	server.SetupCertificates(mainDomainSuffix, acmeAPI, acmeMail, acmeEabHmac, acmeEabKID, dnsProvider, acmeUseRateLimits, acmeAcceptTerms, enableHTTPServer, keyDatabase)
 	if enableHTTPServer {
 		go (func() {
 			challengePath := []byte("/.well-known/acme-challenge/")
diff --git a/server/cache/interface.go b/server/cache/interface.go
new file mode 100644
index 0000000..37ae8f5
--- /dev/null
+++ b/server/cache/interface.go
@@ -0,0 +1,8 @@
+package cache
+
+import "time"
+
+type SetGetKey interface {
+	Set(key string, value interface{}, ttl time.Duration) error
+	Get(key string) (interface{}, bool)
+}
diff --git a/server/cache/setup.go b/server/cache/setup.go
new file mode 100644
index 0000000..a5928b0
--- /dev/null
+++ b/server/cache/setup.go
@@ -0,0 +1,7 @@
+package cache
+
+import "github.com/OrlovEvgeny/go-mcache"
+
+func NewKeyValueCache() SetGetKey {
+	return mcache.New()
+}
diff --git a/server/certificates.go b/server/certificates.go
index 5339375..d6b6b86 100644
--- a/server/certificates.go
+++ b/server/certificates.go
@@ -24,8 +24,6 @@ import (
 	"time"
 
 	"github.com/OrlovEvgeny/go-mcache"
-	"github.com/akrylysov/pogreb"
-	"github.com/akrylysov/pogreb/fs"
 	"github.com/reugn/equalizer"
 
 	"github.com/go-acme/lego/v4/certcrypto"
@@ -36,11 +34,12 @@ import (
 	"github.com/go-acme/lego/v4/providers/dns"
 	"github.com/go-acme/lego/v4/registration"
 
+	"codeberg.org/codeberg/pages/server/cache"
 	"codeberg.org/codeberg/pages/server/database"
 )
 
-// TlsConfig returns the configuration for generating, serving and cleaning up Let's Encrypt certificates.
-func TlsConfig(mainDomainSuffix []byte, giteaRoot, giteaApiToken, dnsProvider string, acmeUseRateLimits bool) *tls.Config {
+// TLSConfig returns the configuration for generating, serving and cleaning up Let's Encrypt certificates.
+func TLSConfig(mainDomainSuffix []byte, giteaRoot, giteaApiToken, dnsProvider string, acmeUseRateLimits bool, keyCache cache.SetGetKey, keyDatabase database.KeyDB) *tls.Config {
 	return &tls.Config{
 		// check DNS name & get certificate from Let's Encrypt
 		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -96,13 +95,13 @@ func TlsConfig(mainDomainSuffix []byte, giteaRoot, giteaApiToken, dnsProvider st
 			var tlsCertificate tls.Certificate
 			var err error
 			var ok bool
-			if tlsCertificate, ok = retrieveCertFromDB(sniBytes, mainDomainSuffix, dnsProvider, acmeUseRateLimits); !ok {
+			if tlsCertificate, ok = retrieveCertFromDB(sniBytes, mainDomainSuffix, dnsProvider, acmeUseRateLimits, keyDatabase); !ok {
 				// request a new certificate
 				if bytes.Equal(sniBytes, mainDomainSuffix) {
 					return nil, errors.New("won't request certificate for main domain, something really bad has happened")
 				}
 
-				tlsCertificate, err = obtainCert(acmeClient, []string{sni}, nil, targetOwner, dnsProvider, mainDomainSuffix, acmeUseRateLimits)
+				tlsCertificate, err = obtainCert(acmeClient, []string{sni}, nil, targetOwner, dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase)
 				if err != nil {
 					return nil, err
 				}
@@ -134,14 +133,6 @@ func TlsConfig(mainDomainSuffix []byte, giteaRoot, giteaApiToken, dnsProvider st
 	}
 }
 
-// TODO: clean up & move to init
-var keyCache = mcache.New()
-var KeyDatabase, KeyDatabaseErr = pogreb.Open("key-database.pogreb", &pogreb.Options{
-	BackgroundSyncInterval:       30 * time.Second,
-	BackgroundCompactionInterval: 6 * time.Hour,
-	FileSystem:                   fs.OSMMap,
-})
-
 func CheckUserLimit(user string) error {
 	userLimit, ok := acmeClientCertificateLimitPerUser[user]
 	if !ok {
@@ -211,10 +202,10 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
 	return nil
 }
 
-func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool) (tls.Certificate, bool) {
+func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.KeyDB) (tls.Certificate, bool) {
 	// parse certificate from database
 	res := &certificate.Resource{}
-	if !database.PogrebGet(KeyDatabase, sni, res) {
+	if !database.PogrebGet(keyDatabase, sni, res) {
 		return tls.Certificate{}, false
 	}
 
@@ -242,7 +233,7 @@ func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUs
 			}
 			go (func() {
 				res.CSR = nil // acme client doesn't like CSR to be set
-				tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits)
+				tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase)
 				if err != nil {
 					log.Printf("Couldn't renew certificate for %s: %s", sni, err)
 				}
@@ -255,7 +246,7 @@ func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUs
 
 var obtainLocks = sync.Map{}
 
-func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user, dnsProvider string, mainDomainSuffix []byte, acmeUseRateLimits bool) (tls.Certificate, error) {
+func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user, dnsProvider string, mainDomainSuffix []byte, acmeUseRateLimits bool, keyDatabase database.KeyDB) (tls.Certificate, error) {
 	name := strings.TrimPrefix(domains[0], "*")
 	if dnsProvider == "" && len(domains[0]) > 0 && domains[0][0] == '*' {
 		domains = domains[1:]
@@ -268,7 +259,7 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
 			time.Sleep(100 * time.Millisecond)
 			_, working = obtainLocks.Load(name)
 		}
-		cert, ok := retrieveCertFromDB([]byte(name), mainDomainSuffix, dnsProvider, acmeUseRateLimits)
+		cert, ok := retrieveCertFromDB([]byte(name), mainDomainSuffix, dnsProvider, acmeUseRateLimits, keyDatabase)
 		if !ok {
 			return tls.Certificate{}, errors.New("certificate failed in synchronous request")
 		}
@@ -277,7 +268,7 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
 	defer obtainLocks.Delete(name)
 
 	if acmeClient == nil {
-		return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", string(mainDomainSuffix)), nil
+		return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", string(mainDomainSuffix), keyDatabase), nil
 	}
 
 	// request actual cert
@@ -319,15 +310,15 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
 			if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
 				// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
 				renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
-				database.PogrebPut(KeyDatabase, []byte(name), renew)
+				database.PogrebPut(keyDatabase, []byte(name), renew)
 				return tlsCertificate, nil
 			}
 		}
-		return mockCert(domains[0], err.Error(), string(mainDomainSuffix)), err
+		return mockCert(domains[0], err.Error(), string(mainDomainSuffix), keyDatabase), err
 	}
 	log.Printf("Obtained certificate for %v", domains)
 
-	database.PogrebPut(KeyDatabase, []byte(name), res)
+	database.PogrebPut(keyDatabase, []byte(name), res)
 	tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
 	if err != nil {
 		return tls.Certificate{}, err
@@ -335,7 +326,7 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
 	return tlsCertificate, nil
 }
 
-func mockCert(domain, msg, mainDomainSuffix string) tls.Certificate {
+func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.KeyDB) tls.Certificate {
 	key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
 	if err != nil {
 		panic(err)
@@ -392,7 +383,7 @@ func mockCert(domain, msg, mainDomainSuffix string) tls.Certificate {
 	if domain == "*"+mainDomainSuffix || domain == mainDomainSuffix[1:] {
 		databaseName = mainDomainSuffix
 	}
-	database.PogrebPut(KeyDatabase, []byte(databaseName), res)
+	database.PogrebPut(keyDatabase, []byte(databaseName), res)
 
 	tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
 	if err != nil {
@@ -401,13 +392,9 @@ func mockCert(domain, msg, mainDomainSuffix string) tls.Certificate {
 	return tlsCertificate
 }
 
-func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac, acmeEabKID, dnsProvider string, acmeUseRateLimits, acmeAcceptTerms, enableHTTPServer bool) {
-	if KeyDatabaseErr != nil {
-		panic(KeyDatabaseErr) // TODO: move it into own init and not panic on a unrelated topic!!!!
-	}
-
+func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac, acmeEabKID, dnsProvider string, acmeUseRateLimits, acmeAcceptTerms, enableHTTPServer bool, keyDatabase database.KeyDB) {
 	// getting main cert before ACME account so that we can panic here on database failure without hitting rate limits
-	mainCertBytes, err := KeyDatabase.Get(mainDomainSuffix)
+	mainCertBytes, err := keyDatabase.Get(mainDomainSuffix)
 	if err != nil {
 		// key database is not working
 		panic(err)
@@ -523,7 +510,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 	}
 
 	if mainCertBytes == nil {
-		_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, nil, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits)
+		_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, nil, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase)
 		if err != nil {
 			log.Printf("[ERROR] Couldn't renew main domain certificate, continuing with mock certs only: %s", err)
 		}
@@ -531,7 +518,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 
 	go (func() {
 		for {
-			err := KeyDatabase.Sync()
+			err := keyDatabase.Sync()
 			if err != nil {
 				log.Printf("[ERROR] Syncing key database failed: %s", err)
 			}
@@ -544,7 +531,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 			// clean up expired certs
 			now := time.Now()
 			expiredCertCount := 0
-			keyDatabaseIterator := KeyDatabase.Items()
+			keyDatabaseIterator := keyDatabase.Items()
 			key, resBytes, err := keyDatabaseIterator.Next()
 			for err == nil {
 				if !bytes.Equal(key, mainDomainSuffix) {
@@ -558,7 +545,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 
 					tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
 					if err != nil || !tlsCertificates[0].NotAfter.After(now) {
-						err := KeyDatabase.Delete(key)
+						err := keyDatabase.Delete(key)
 						if err != nil {
 							log.Printf("[ERROR] Deleting expired certificate for %s failed: %s", string(key), err)
 						} else {
@@ -571,7 +558,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 			log.Printf("[INFO] Removed %d expired certificates from the database", expiredCertCount)
 
 			// compact the database
-			result, err := KeyDatabase.Compact()
+			result, err := keyDatabase.Compact()
 			if err != nil {
 				log.Printf("[ERROR] Compacting key database failed: %s", err)
 			} else {
@@ -580,7 +567,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 
 			// update main cert
 			res := &certificate.Resource{}
-			if !database.PogrebGet(KeyDatabase, mainDomainSuffix, res) {
+			if !database.PogrebGet(keyDatabase, mainDomainSuffix, res) {
 				log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted")
 			} else {
 				tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
@@ -588,7 +575,7 @@ func SetupCertificates(mainDomainSuffix []byte, acmeAPI, acmeMail, acmeEabHmac,
 				// renew main certificate 30 days before it expires
 				if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
 					go (func() {
-						_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits)
+						_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase)
 						if err != nil {
 							log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err)
 						}
diff --git a/server/database/helpers.go b/server/database/helpers.go
index b2eb017..98ea3fa 100644
--- a/server/database/helpers.go
+++ b/server/database/helpers.go
@@ -3,10 +3,9 @@ package database
 import (
 	"bytes"
 	"encoding/gob"
-	"github.com/akrylysov/pogreb"
 )
 
-func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) {
+func PogrebPut(db KeyDB, name []byte, obj interface{}) {
 	var resGob bytes.Buffer
 	resEnc := gob.NewEncoder(&resGob)
 	err := resEnc.Encode(obj)
@@ -19,7 +18,7 @@ func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) {
 	}
 }
 
-func PogrebGet(db *pogreb.DB, name []byte, obj interface{}) bool {
+func PogrebGet(db KeyDB, name []byte, obj interface{}) bool {
 	resBytes, err := db.Get(name)
 	if err != nil {
 		panic(err)
diff --git a/server/database/interface.go b/server/database/interface.go
new file mode 100644
index 0000000..2b582ae
--- /dev/null
+++ b/server/database/interface.go
@@ -0,0 +1,12 @@
+package database
+
+import "github.com/akrylysov/pogreb"
+
+type KeyDB interface {
+	Sync() error
+	Put(key []byte, value []byte) error
+	Get(key []byte) ([]byte, error)
+	Delete(key []byte) error
+	Compact() (pogreb.CompactionResult, error)
+	Items() *pogreb.ItemIterator
+}
diff --git a/server/database/setup.go b/server/database/setup.go
new file mode 100644
index 0000000..c16ff36
--- /dev/null
+++ b/server/database/setup.go
@@ -0,0 +1,19 @@
+package database
+
+import (
+	"fmt"
+	"github.com/akrylysov/pogreb"
+	"github.com/akrylysov/pogreb/fs"
+	"time"
+)
+
+func New(path string) (KeyDB, error) {
+	if path == "" {
+		return nil, fmt.Errorf("path not set")
+	}
+	return pogreb.Open(path, &pogreb.Options{
+		BackgroundSyncInterval:       30 * time.Second,
+		BackgroundCompactionInterval: 6 * time.Hour,
+		FileSystem:                   fs.OSMMap,
+	})
+}