Merge branch 'dn/test-reject-utf-16'
[git] / contrib / persistent-https / proxy.go
1 // Copyright 2012 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "fmt"
19         "log"
20         "net"
21         "net/http"
22         "net/http/httputil"
23         "os"
24         "os/exec"
25         "os/signal"
26         "sync"
27         "syscall"
28         "time"
29 )
30
31 type Proxy struct {
32         BuildLabel         string
33         MaxIdleDuration    time.Duration
34         PollUpdateInterval time.Duration
35
36         ul        net.Listener
37         httpAddr  string
38         httpsAddr string
39 }
40
41 func (p *Proxy) Run() error {
42         hl, err := net.Listen("tcp", "127.0.0.1:0")
43         if err != nil {
44                 return fmt.Errorf("http listen failed: %v", err)
45         }
46         defer hl.Close()
47
48         hsl, err := net.Listen("tcp", "127.0.0.1:0")
49         if err != nil {
50                 return fmt.Errorf("https listen failed: %v", err)
51         }
52         defer hsl.Close()
53
54         p.ul, err = DefaultSocket.Listen()
55         if err != nil {
56                 c, derr := DefaultSocket.Dial()
57                 if derr == nil {
58                         c.Close()
59                         fmt.Println("OK\nA proxy is already running... exiting")
60                         return nil
61                 } else if e, ok := derr.(*net.OpError); ok && e.Err == syscall.ECONNREFUSED {
62                         // Nothing is listening on the socket, unlink it and try again.
63                         syscall.Unlink(DefaultSocket.Path())
64                         p.ul, err = DefaultSocket.Listen()
65                 }
66                 if err != nil {
67                         return fmt.Errorf("unix listen failed on %v: %v", DefaultSocket.Path(), err)
68                 }
69         }
70         defer p.ul.Close()
71         go p.closeOnSignal()
72         go p.closeOnUpdate()
73
74         p.httpAddr = hl.Addr().String()
75         p.httpsAddr = hsl.Addr().String()
76         fmt.Printf("OK\nListening on unix socket=%v http=%v https=%v\n",
77                 p.ul.Addr(), p.httpAddr, p.httpsAddr)
78
79         result := make(chan error, 2)
80         go p.serveUnix(result)
81         go func() {
82                 result <- http.Serve(hl, &httputil.ReverseProxy{
83                         FlushInterval: 500 * time.Millisecond,
84                         Director:      func(r *http.Request) {},
85                 })
86         }()
87         go func() {
88                 result <- http.Serve(hsl, &httputil.ReverseProxy{
89                         FlushInterval: 500 * time.Millisecond,
90                         Director: func(r *http.Request) {
91                                 r.URL.Scheme = "https"
92                         },
93                 })
94         }()
95         return <-result
96 }
97
98 type socketContext struct {
99         sync.WaitGroup
100         mutex sync.Mutex
101         last  time.Time
102 }
103
104 func (sc *socketContext) Done() {
105         sc.mutex.Lock()
106         defer sc.mutex.Unlock()
107         sc.last = time.Now()
108         sc.WaitGroup.Done()
109 }
110
111 func (p *Proxy) serveUnix(result chan<- error) {
112         sockCtx := &socketContext{}
113         go p.closeOnIdle(sockCtx)
114
115         var err error
116         for {
117                 var uconn net.Conn
118                 uconn, err = p.ul.Accept()
119                 if err != nil {
120                         err = fmt.Errorf("accept failed: %v", err)
121                         break
122                 }
123                 sockCtx.Add(1)
124                 go p.handleUnixConn(sockCtx, uconn)
125         }
126         sockCtx.Wait()
127         result <- err
128 }
129
130 func (p *Proxy) handleUnixConn(sockCtx *socketContext, uconn net.Conn) {
131         defer sockCtx.Done()
132         defer uconn.Close()
133         data := []byte(fmt.Sprintf("%v\n%v", p.httpsAddr, p.httpAddr))
134         uconn.SetDeadline(time.Now().Add(5 * time.Second))
135         for i := 0; i < 2; i++ {
136                 if n, err := uconn.Write(data); err != nil {
137                         log.Printf("error sending http addresses: %+v\n", err)
138                         return
139                 } else if n != len(data) {
140                         log.Printf("sent %d data bytes, wanted %d\n", n, len(data))
141                         return
142                 }
143                 if _, err := uconn.Read([]byte{0, 0, 0, 0}); err != nil {
144                         log.Printf("error waiting for Ack: %+v\n", err)
145                         return
146                 }
147         }
148         // Wait without a deadline for the client to finish via EOF
149         uconn.SetDeadline(time.Time{})
150         uconn.Read([]byte{0, 0, 0, 0})
151 }
152
153 func (p *Proxy) closeOnIdle(sockCtx *socketContext) {
154         for d := p.MaxIdleDuration; d > 0; {
155                 time.Sleep(d)
156                 sockCtx.Wait()
157                 sockCtx.mutex.Lock()
158                 if d = sockCtx.last.Add(p.MaxIdleDuration).Sub(time.Now()); d <= 0 {
159                         log.Println("graceful shutdown from idle timeout")
160                         p.ul.Close()
161                 }
162                 sockCtx.mutex.Unlock()
163         }
164 }
165
166 func (p *Proxy) closeOnUpdate() {
167         for {
168                 time.Sleep(p.PollUpdateInterval)
169                 if out, err := exec.Command(os.Args[0], "--print_label").Output(); err != nil {
170                         log.Printf("error polling for updated binary: %v\n", err)
171                 } else if s := string(out[:len(out)-1]); p.BuildLabel != s {
172                         log.Printf("graceful shutdown from updated binary: %q --> %q\n", p.BuildLabel, s)
173                         p.ul.Close()
174                         break
175                 }
176         }
177 }
178
179 func (p *Proxy) closeOnSignal() {
180         ch := make(chan os.Signal, 10)
181         signal.Notify(ch, os.Interrupt, os.Kill, os.Signal(syscall.SIGTERM), os.Signal(syscall.SIGHUP))
182         sig := <-ch
183         p.ul.Close()
184         switch sig {
185         case os.Signal(syscall.SIGHUP):
186                 log.Printf("graceful shutdown from signal: %v\n", sig)
187         default:
188                 log.Fatalf("exiting from signal: %v\n", sig)
189         }
190 }