// lotus/lotuspond/outmux.go

package main
import (
"fmt"
"github.com/gorilla/websocket"
"github.com/opentracing/opentracing-go/log"
"io"
"net/http"
"strings"
"sync"
)
type outmux struct {
lk sync.Mutex
errpw *io.PipeWriter
outpw *io.PipeWriter
errpr *io.PipeReader
outpr *io.PipeReader
2019-08-19 22:38:32 +00:00
n uint64
2019-08-19 20:16:27 +00:00
outs map[uint64]*websocket.Conn
2019-08-19 22:38:32 +00:00
new chan *websocket.Conn
2019-08-19 20:16:27 +00:00
stop chan struct{}
}
func newWsMux() *outmux {
out := &outmux{
2019-08-19 22:38:32 +00:00
n: 0,
2019-08-19 20:16:27 +00:00
outs: map[uint64]*websocket.Conn{},
2019-08-19 22:38:32 +00:00
new: make(chan *websocket.Conn),
stop: make(chan struct{}),
2019-08-19 20:16:27 +00:00
}
out.outpr, out.outpw = io.Pipe()
out.errpr, out.errpw = io.Pipe()
go out.run()
return out
}
// msgsToChan forwards data read from r to ch, one chunk per successful Read,
// until r returns an error (for an io.Pipe, when the write side is closed) or
// m.stop is closed. ch is closed on return so the receiver can tell the stream
// has ended.
func (m *outmux) msgsToChan(r *io.PipeReader, ch chan []byte) {
	// Large enough that a typical pipe Write arrives as one message instead of
	// one websocket frame per byte.
	const chunkSize = 4096

	defer close(ch)
	for {
		// Allocate per iteration: the slice is handed to another goroutine,
		// so the buffer must not be reused for the next Read.
		buf := make([]byte, chunkSize)
		n, err := r.Read(buf)

		// Per the io.Reader contract, deliver any bytes read before looking
		// at err, and skip empty reads.
		if n > 0 {
			select {
			case ch <- buf[:n]:
			case <-m.stop:
				return
			}
		}
		if err != nil {
			return
		}
	}
}
// run is the mux's event loop and the sole owner of m.outs. It starts readers
// for the stdout and stderr pipes and broadcasts every chunk they produce to
// all registered connections. It registers connections arriving on m.new. When
// m.stop is closed it closes every connection and returns.
func (m *outmux) run() {
	stdout := make(chan []byte)
	stderr := make(chan []byte)
	go m.msgsToChan(m.outpr, stdout)
	go m.msgsToChan(m.errpr, stderr)

	for {
		select {
		case msg, ok := <-stdout:
			if !ok {
				// msgsToChan closed the channel. A nil channel is never ready
				// in a select, so this stops the loop from spinning on a
				// closed channel and broadcasting empty messages.
				stdout = nil
				continue
			}
			m.broadcast(msg)
		case msg, ok := <-stderr:
			if !ok {
				stderr = nil
				continue
			}
			m.broadcast(msg)
		case c := <-m.new:
			m.n++
			m.outs[m.n] = c
		case <-m.stop:
			for _, out := range m.outs {
				out.Close()
			}
			return
		}
	}
}

// broadcast writes msg as a binary message to every registered connection.
// Connections whose write fails are closed and unregistered. Only run calls
// it, because run is the only goroutine that touches m.outs.
func (m *outmux) broadcast(msg []byte) {
	for k, out := range m.outs {
		if err := out.WriteMessage(websocket.BinaryMessage, msg); err != nil {
			out.Close()
			fmt.Printf("outmux write failed: %s\n", err)
			delete(m.outs, k) // deleting during range is safe in Go
		}
	}
}
// upgrader turns incoming HTTP requests into websocket connections for
// outmux.ServeHTTP. Its CheckOrigin always reports true, so a request is
// accepted whatever its Origin header says — any web page can connect.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
func (m *outmux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Connection"), "Upgrade") {
fmt.Println("noupgrade")
w.WriteHeader(500)
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Header.Get("Sec-WebSocket-Protocol") != "" {
w.Header().Set("Sec-WebSocket-Protocol", r.Header.Get("Sec-WebSocket-Protocol"))
}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Error(err)
w.WriteHeader(500)
return
}
m.new <- c
return
2019-08-19 22:38:32 +00:00
}