// internal/devserver/livereload.go

package devserver

import (
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"
)

// LiveReload is a minimal Server-Sent Events (SSE) broadcaster. Each
// connected browser subscribes for "reload" events and re-fetches the page
// when one arrives. SSE was picked over websockets deliberately: it has no
// handshake complexity and copes better with proxies, which matters when
// working over ssh port forwards.
//
// Use NewLiveReload to obtain a ready-to-use value.
type LiveReload struct {
	// Period is the interval between heartbeat comments written to every
	// open stream.
	Period time.Duration

	// mu guards peers.
	mu sync.RWMutex

	// peers is the set of subscriber channels, one per open stream.
	peers map[chan string]struct{}
}

// NewLiveReload builds a broadcaster with an empty listener set and a
// 15-second heartbeat period.
func NewLiveReload() *LiveReload {
	const defaultPeriod = 15 * time.Second

	lr := new(LiveReload)
	lr.peers = make(map[chan string]struct{})
	lr.Period = defaultPeriod
	return lr
}

// ServeHTTP subscribes the requester to reload events and streams them as
// Server-Sent Events until the request context is cancelled or a write to
// the client fails.
//
// The response writer must implement http.Flusher; otherwise the handler
// answers with a 500 "streaming unsupported" error. The stream opens with a
// "retry: 1000" field, carries one "reload" event per broadcast message and
// emits a ": heartbeat" comment every lr.Period. A non-positive Period falls
// back to 15 seconds, because time.NewTicker panics on such durations.
func (lr *LiveReload) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "text/event-stream")
	hdr.Set("Cache-Control", "no-cache")
	hdr.Set("Connection", "keep-alive")

	// Small buffer: lets a few broadcasts queue up while this handler is
	// busy writing, without ever blocking the broadcaster.
	ch := make(chan string, 4)
	lr.register(ch)
	defer lr.deregister(ch)

	// Period is an exported field and may have been left at zero or set to
	// a negative value; time.NewTicker would panic on either.
	period := lr.Period
	if period <= 0 {
		period = 15 * time.Second
	}
	heartbeat := time.NewTicker(period)
	defer heartbeat.Stop()

	// Ask the browser to wait one second before reconnecting.
	if _, err := fmt.Fprint(w, "retry: 1000\n\n"); err != nil {
		return
	}
	flusher.Flush()

	done := r.Context().Done()
	for {
		select {
		case <-done:
			return
		case msg := <-ch:
			if _, err := fmt.Fprint(w, sseReloadFrame(msg)); err != nil {
				return // the client is gone; stop streaming
			}
			flusher.Flush()
		case <-heartbeat.C:
			// Comment line: ignored by the client, keeps the connection
			// from sitting idle.
			if _, err := fmt.Fprint(w, ": heartbeat\n\n"); err != nil {
				return
			}
			flusher.Flush()
		}
	}
}

// sseReloadFrame renders payload as one complete SSE "reload" event. Every
// line of the payload gets its own "data:" field, so embedded line breaks
// (LF, CR or CRLF) can neither end the event early nor inject extra fields;
// the client joins the data lines back together with "\n". A single-line
// payload produces exactly "event: reload\ndata: <payload>\n\n".
func sseReloadFrame(payload string) string {
	// SSE treats CRLF, CR and LF all as line terminators; normalise to LF.
	payload = strings.ReplaceAll(payload, "\r\n", "\n")
	payload = strings.ReplaceAll(payload, "\r", "\n")

	var b strings.Builder
	b.WriteString("event: reload\n")
	for _, line := range strings.Split(payload, "\n") {
		b.WriteString("data: ")
		b.WriteString(line)
		b.WriteByte('\n')
	}
	b.WriteByte('\n') // the blank line dispatches the event
	return b.String()
}

// Broadcast notifies every listener with the default "reload" payload.
// Delivery is non-blocking: a listener that is too slow to keep up misses
// the message instead of holding up the sender.
func (lr *LiveReload) Broadcast() {
	const payload = "reload"
	lr.BroadcastMessage(payload)
}

// BroadcastMessage offers msg, an arbitrary caller-chosen payload, to every
// registered listener. A listener whose channel cannot accept the message
// right away is skipped.
func (lr *LiveReload) BroadcastMessage(msg string) {
	// offer attempts a send and gives up immediately if it would block.
	offer := func(peer chan<- string) {
		select {
		case peer <- msg:
		default:
		}
	}

	lr.mu.RLock()
	defer lr.mu.RUnlock()

	for peer := range lr.peers {
		offer(peer)
	}
}

// Peers reports how many listeners are subscribed at the moment of the call.
func (lr *LiveReload) Peers() int {
	lr.mu.RLock()
	count := len(lr.peers)
	lr.mu.RUnlock()
	return count
}

// register adds ch to the set of listeners that BroadcastMessage sends to.
// The set is allocated on first use, so a LiveReload that was not built by
// NewLiveReload does not panic on a write to a nil map.
func (lr *LiveReload) register(ch chan string) {
	lr.mu.Lock()
	defer lr.mu.Unlock()
	if lr.peers == nil {
		lr.peers = make(map[chan string]struct{})
	}
	lr.peers[ch] = struct{}{}
}

// deregister removes ch from the listener set and then closes it. Both steps
// run under the write lock; BroadcastMessage sends while holding the read
// lock, so it cannot be sending on ch while ch is being closed. Closing an
// already-closed channel panics, so this must run once per channel.
func (lr *LiveReload) deregister(ch chan string) {
	lr.mu.Lock()
	defer lr.mu.Unlock()
	// Remove before closing so the set never holds a closed channel.
	delete(lr.peers, ch)
	close(ch)
}