1
0
mirror of https://github.com/makew0rld/amfora.git synced 2025-01-03 14:56:27 -05:00

🐛 Prevent concurrent map panic for tofuStore

This commit is contained in:
makeworld 2020-12-09 15:17:07 -05:00
parent 4a2c7da529
commit 1735f1c53c

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"time" "time"
"github.com/makeworld-the-better-one/amfora/config" "github.com/makeworld-the-better-one/amfora/config"
@ -21,6 +22,12 @@ var ErrTofu = errors.New("server cert does not match TOFU database")
var tofuStore = config.TofuStore var tofuStore = config.TofuStore
// tofuStoreMu protects tofuStore, since viper is not thread-safe.
// See this issue for details: https://github.com/spf13/viper/issues/268
// This is needed because Gemini requests may happen concurrently and
// call on the funcs on this file.
var tofuStoreMu = sync.RWMutex{}
// idKey returns the config/viper key needed to retrieve // idKey returns the config/viper key needed to retrieve
// a cert's ID / fingerprint. // a cert's ID / fingerprint.
func idKey(domain string, port string) string { func idKey(domain string, port string) string {
@ -38,6 +45,9 @@ func expiryKey(domain string, port string) string {
} }
func loadTofuEntry(domain string, port string) (string, time.Time, error) { func loadTofuEntry(domain string, port string) (string, time.Time, error) {
tofuStoreMu.RLock()
defer tofuStoreMu.RUnlock()
id := tofuStore.GetString(idKey(domain, port)) // Fingerprint id := tofuStore.GetString(idKey(domain, port)) // Fingerprint
if len(id) != sha256.Size*2 { if len(id) != sha256.Size*2 {
// Not set, or invalid // Not set, or invalid
@ -68,6 +78,9 @@ func origCertID(cert *x509.Certificate) string {
} }
func saveTofuEntry(domain, port string, cert *x509.Certificate) { func saveTofuEntry(domain, port string, cert *x509.Certificate) {
tofuStoreMu.Lock()
defer tofuStoreMu.Unlock()
tofuStore.Set(idKey(domain, port), certID(cert)) tofuStore.Set(idKey(domain, port), certID(cert))
tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC()) tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC())
tofuStore.WriteConfig() //nolint:errcheck // Not an issue if it's not saved, only cached data tofuStore.WriteConfig() //nolint:errcheck // Not an issue if it's not saved, only cached data
@ -90,9 +103,10 @@ func handleTofu(domain, port string, cert *x509.Certificate) bool {
// Same cert as the one stored // Same cert as the one stored
// Store expiry again in case it changed // Store expiry again in case it changed
tofuStoreMu.Lock()
tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC()) tofuStore.Set(expiryKey(domain, port), cert.NotAfter.UTC())
tofuStore.WriteConfig() //nolint:errcheck tofuStore.WriteConfig() //nolint:errcheck
tofuStoreMu.Unlock()
return true return true
} }
if origCertID(cert) == id { if origCertID(cert) == id {
@ -117,5 +131,8 @@ func ResetTofuEntry(domain, port string, cert *x509.Certificate) {
// GetExpiry returns the stored expiry date for the given host. // GetExpiry returns the stored expiry date for the given host.
// The time will be empty (zero) if there is not expiry date stored for that host. // The time will be empty (zero) if there is not expiry date stored for that host.
func GetExpiry(domain, port string) time.Time { func GetExpiry(domain, port string) time.Time {
tofuStoreMu.RLock()
defer tofuStoreMu.RUnlock()
return tofuStore.GetTime(expiryKey(domain, port)) return tofuStore.GetTime(expiryKey(domain, port))
} }