From 544b3f73217522ff28c6f4ed99470290f60749d4 Mon Sep 17 00:00:00 2001
From: Moritz Marquardt <git@momar.de>
Date: Wed, 1 Dec 2021 22:49:48 +0100
Subject: [PATCH] (Ab)use CSR field to store try-again date for renewals
 (instead of showing a mock cert), must be tested when the first renewals are
 due

---
 certificates.go | 92 ++++++++++++++++++++-----------------------------
 helpers.go      | 37 +++++++++++++++++++-
 2 files changed, 73 insertions(+), 56 deletions(-)

diff --git a/certificates.go b/certificates.go
index f85109d..36f0df0 100644
--- a/certificates.go
+++ b/certificates.go
@@ -24,6 +24,7 @@ import (
 	"log"
 	"math/big"
 	"os"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -207,21 +208,9 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
 
 func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) {
 	// parse certificate from database
-	resBytes, err := keyDatabase.Get(sni)
-	if err != nil {
-		// key database is not working
-		panic(err)
-	}
-	if resBytes == nil {
-		return tls.Certificate{}, false
-	}
-
-	resGob := bytes.NewBuffer(resBytes)
-	resDec := gob.NewDecoder(resGob)
 	res := &certificate.Resource{}
-	err = resDec.Decode(res)
-	if err != nil {
-		panic(err)
+	if !PogrebGet(keyDatabase, sni, res) {
+		return tls.Certificate{}, false
 	}
 
 	tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
@@ -237,7 +226,15 @@ func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) {
 
 		// renew certificates 7 days before they expire
 		if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) {
+			if res.CSR != nil && len(res.CSR) > 0 {
+				// CSR stores the time when the renewal shall be tried again
+				nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64)
+				if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) {
+					return tlsCertificate, true
+				}
+			}
 			go (func() {
+				res.CSR = nil // acme client doesn't like CSR to be set
 				tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "")
 				if err != nil {
 					log.Printf("Couldn't renew certificate for %s: %s", sni, err)
@@ -310,18 +307,21 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
 	}
 	if err != nil {
 		log.Printf("Couldn't obtain certificate for %v: %s", domains, err)
-		return mockCert(domains[0], err.Error()), err
+		if renew != nil && renew.CertURL != "" {
+			tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
+			if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
+				// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
+				renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6 * time.Hour).Unix(), 10))
+				PogrebPut(keyDatabase, []byte(name), renew)
+				return tlsCertificate, nil
+			}
+		} else {
+			return mockCert(domains[0], err.Error()), err
+		}
 	}
 	log.Printf("Obtained certificate for %v", domains)
 
-	var resGob bytes.Buffer
-	resEnc := gob.NewEncoder(&resGob)
-	err = resEnc.Encode(res)
-	if err != nil {
-		panic(err)
-	}
-	err = keyDatabase.Put([]byte(name), resGob.Bytes())
-
+	PogrebPut(keyDatabase, []byte(name), res)
 	tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
 	if err != nil {
 		return tls.Certificate{}, err
@@ -382,20 +382,11 @@ func mockCert(domain string, msg string) tls.Certificate {
 		IssuerCertificate: outBytes,
 		Domain: domain,
 	}
-	var resGob bytes.Buffer
-	resEnc := gob.NewEncoder(&resGob)
-	err = resEnc.Encode(res)
-	if err != nil {
-		panic(err)
-	}
 	databaseName := domain
 	if domain == "*" + string(MainDomainSuffix) || domain == string(MainDomainSuffix[1:]) {
 		databaseName = string(MainDomainSuffix)
 	}
-	err = keyDatabase.Put([]byte(databaseName), resGob.Bytes())
-	if err != nil {
-		panic(err)
-	}
+	PogrebPut(keyDatabase, []byte(databaseName), res)
 
 	tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
 	if err != nil {
@@ -585,30 +576,21 @@ func setupCertificates() {
 			}
 
 			// update main cert
-			resBytes, err = keyDatabase.Get(MainDomainSuffix)
-			if err != nil {
-				// key database is not working
-				panic(err)
-			}
-
-			resGob := bytes.NewBuffer(resBytes)
-			resDec := gob.NewDecoder(resGob)
 			res := &certificate.Resource{}
-			err = resDec.Decode(res)
-			if err != nil {
-				panic(err)
-			}
+			if !PogrebGet(keyDatabase, MainDomainSuffix, res) {
+				log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted")
+			} else {
+				tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
 
-			tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
-
-			// renew main certificate 30 days before it expires
-			if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
-				go (func() {
-					_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "")
-					if err != nil {
-						log.Printf("Couldn't renew certificate for *%s: %s", MainDomainSuffix, err)
-					}
-				})()
+				// renew main certificate 30 days before it expires
+				if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
+					go (func() {
+						_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "")
+						if err != nil {
+							log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err)
+						}
+					})()
+				}
 			}
 
 			time.Sleep(12 * time.Hour)
diff --git a/helpers.go b/helpers.go
index 506573b..46a1492 100644
--- a/helpers.go
+++ b/helpers.go
@@ -1,6 +1,10 @@
 package main
 
-import "bytes"
+import (
+	"bytes"
+	"encoding/gob"
+	"github.com/akrylysov/pogreb"
+)
 
 // GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty
 // string for custom domains.
@@ -19,3 +23,34 @@ func TrimHostPort(host []byte) []byte {
 	}
 	return host
 }
+
+func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) {
+	var resGob bytes.Buffer
+	resEnc := gob.NewEncoder(&resGob)
+	err := resEnc.Encode(obj)
+	if err != nil {
+		panic(err)
+	}
+	err = db.Put(name, resGob.Bytes())
+	if err != nil {
+		panic(err)
+	}
+}
+
+func PogrebGet(db *pogreb.DB, name []byte, obj interface{}) bool {
+	resBytes, err := db.Get(name)
+	if err != nil {
+		panic(err)
+	}
+	if resBytes == nil {
+		return false
+	}
+
+	resGob := bytes.NewBuffer(resBytes)
+	resDec := gob.NewDecoder(resGob)
+	err = resDec.Decode(obj)
+	if err != nil {
+		panic(err)
+	}
+	return true
+}