diff options
author | Joe Slater <joe.slater@windriver.com> | 2023-04-12 12:24:48 -0700 |
---|---|---|
committer | Steve Sakoman <steve@sakoman.com> | 2023-04-19 04:45:00 -1000 |
commit | 217a47bdb4bf8e5f812428fdd234bbe2696c02bd (patch) | |
tree | a569e6654cbf3dfaa7f52b4a9a6ea68cfd6bd879 /meta | |
parent | 25aa5dfcae84ed9f6a7aa2d0e14ff1f6a082b503 (diff) | |
download | poky-217a47bdb4bf8e5f812428fdd234bbe2696c02bd.tar.gz |
go: fix CVE-2022-41724, 41725
Backport from go-1.19. The godebug package is needed by
the fix to CVE-2022-41725.
Mostly a cherry-pick but exceptions are noted in comments
marked "backport".
(From OE-Core rev: e5cf04f55b4849ae6db1253b39ad8b037cf01af4)
Signed-off-by: Joe Slater <joe.slater@windriver.com>
Signed-off-by: Steve Sakoman <steve@sakoman.com>
Diffstat (limited to 'meta')
-rw-r--r-- | meta/recipes-devtools/go/go-1.17.13.inc | 5 | ||||
-rw-r--r-- | meta/recipes-devtools/go/go-1.19/add_godebug.patch | 84 | ||||
-rw-r--r-- | meta/recipes-devtools/go/go-1.19/cve-2022-41724.patch | 2391 | ||||
-rw-r--r-- | meta/recipes-devtools/go/go-1.19/cve-2022-41725.patch | 652 |
4 files changed, 3131 insertions, 1 deletions
diff --git a/meta/recipes-devtools/go/go-1.17.13.inc b/meta/recipes-devtools/go/go-1.17.13.inc index 14d58932dc..23380f04c3 100644 --- a/meta/recipes-devtools/go/go-1.17.13.inc +++ b/meta/recipes-devtools/go/go-1.17.13.inc | |||
@@ -1,6 +1,6 @@ | |||
1 | require go-common.inc | 1 | require go-common.inc |
2 | 2 | ||
3 | FILESEXTRAPATHS:prepend := "${FILE_DIRNAME}/go-1.18:" | 3 | FILESEXTRAPATHS:prepend := "${FILE_DIRNAME}/go-1.19:${FILE_DIRNAME}/go-1.18:" |
4 | 4 | ||
5 | LIC_FILES_CHKSUM = "file://LICENSE;md5=5d4950ecb7b26d2c5e4e7b4e0dd74707" | 5 | LIC_FILES_CHKSUM = "file://LICENSE;md5=5d4950ecb7b26d2c5e4e7b4e0dd74707" |
6 | 6 | ||
@@ -23,6 +23,9 @@ SRC_URI += "\ | |||
23 | file://CVE-2022-2879.patch \ | 23 | file://CVE-2022-2879.patch \ |
24 | file://CVE-2022-41720.patch \ | 24 | file://CVE-2022-41720.patch \ |
25 | file://CVE-2022-41723.patch \ | 25 | file://CVE-2022-41723.patch \ |
26 | file://cve-2022-41724.patch \ | ||
27 | file://add_godebug.patch \ | ||
28 | file://cve-2022-41725.patch \ | ||
26 | " | 29 | " |
27 | SRC_URI[main.sha256sum] = "a1a48b23afb206f95e7bbaa9b898d965f90826f6f1d1fc0c1d784ada0cd300fd" | 30 | SRC_URI[main.sha256sum] = "a1a48b23afb206f95e7bbaa9b898d965f90826f6f1d1fc0c1d784ada0cd300fd" |
28 | 31 | ||
diff --git a/meta/recipes-devtools/go/go-1.19/add_godebug.patch b/meta/recipes-devtools/go/go-1.19/add_godebug.patch new file mode 100644 index 0000000000..0c3d2d2855 --- /dev/null +++ b/meta/recipes-devtools/go/go-1.19/add_godebug.patch | |||
@@ -0,0 +1,84 @@ | |||
1 | |||
2 | Upstream-Status: Backport [see text] | ||
3 | |||
4 | https://github.com/golong/go.git as of commit 22c1d18a27... | ||
5 | Copy src/internal/godebug from go 1.19 since it does not | ||
6 | exist in 1.17. | ||
7 | |||
8 | Signed-off-by: Joe Slater <joe.slater@windriver.com> | ||
9 | --- | ||
10 | |||
11 | --- /dev/null | ||
12 | +++ go/src/internal/godebug/godebug.go | ||
13 | @@ -0,0 +1,34 @@ | ||
14 | +// Copyright 2021 The Go Authors. All rights reserved. | ||
15 | +// Use of this source code is governed by a BSD-style | ||
16 | +// license that can be found in the LICENSE file. | ||
17 | + | ||
18 | +// Package godebug parses the GODEBUG environment variable. | ||
19 | +package godebug | ||
20 | + | ||
21 | +import "os" | ||
22 | + | ||
23 | +// Get returns the value for the provided GODEBUG key. | ||
24 | +func Get(key string) string { | ||
25 | + return get(os.Getenv("GODEBUG"), key) | ||
26 | +} | ||
27 | + | ||
28 | +// get returns the value part of key=value in s (a GODEBUG value). | ||
29 | +func get(s, key string) string { | ||
30 | + for i := 0; i < len(s)-len(key)-1; i++ { | ||
31 | + if i > 0 && s[i-1] != ',' { | ||
32 | + continue | ||
33 | + } | ||
34 | + afterKey := s[i+len(key):] | ||
35 | + if afterKey[0] != '=' || s[i:i+len(key)] != key { | ||
36 | + continue | ||
37 | + } | ||
38 | + val := afterKey[1:] | ||
39 | + for i, b := range val { | ||
40 | + if b == ',' { | ||
41 | + return val[:i] | ||
42 | + } | ||
43 | + } | ||
44 | + return val | ||
45 | + } | ||
46 | + return "" | ||
47 | +} | ||
48 | --- /dev/null | ||
49 | +++ go/src/internal/godebug/godebug_test.go | ||
50 | @@ -0,0 +1,34 @@ | ||
51 | +// Copyright 2021 The Go Authors. All rights reserved. | ||
52 | +// Use of this source code is governed by a BSD-style | ||
53 | +// license that can be found in the LICENSE file. | ||
54 | + | ||
55 | +package godebug | ||
56 | + | ||
57 | +import "testing" | ||
58 | + | ||
59 | +func TestGet(t *testing.T) { | ||
60 | + tests := []struct { | ||
61 | + godebug string | ||
62 | + key string | ||
63 | + want string | ||
64 | + }{ | ||
65 | + {"", "", ""}, | ||
66 | + {"", "foo", ""}, | ||
67 | + {"foo=bar", "foo", "bar"}, | ||
68 | + {"foo=bar,after=x", "foo", "bar"}, | ||
69 | + {"before=x,foo=bar,after=x", "foo", "bar"}, | ||
70 | + {"before=x,foo=bar", "foo", "bar"}, | ||
71 | + {",,,foo=bar,,,", "foo", "bar"}, | ||
72 | + {"foodecoy=wrong,foo=bar", "foo", "bar"}, | ||
73 | + {"foo=", "foo", ""}, | ||
74 | + {"foo", "foo", ""}, | ||
75 | + {",foo", "foo", ""}, | ||
76 | + {"foo=bar,baz", "loooooooong", ""}, | ||
77 | + } | ||
78 | + for _, tt := range tests { | ||
79 | + got := get(tt.godebug, tt.key) | ||
80 | + if got != tt.want { | ||
81 | + t.Errorf("get(%q, %q) = %q; want %q", tt.godebug, tt.key, got, tt.want) | ||
82 | + } | ||
83 | + } | ||
84 | +} | ||
diff --git a/meta/recipes-devtools/go/go-1.19/cve-2022-41724.patch b/meta/recipes-devtools/go/go-1.19/cve-2022-41724.patch new file mode 100644 index 0000000000..aacffbffcd --- /dev/null +++ b/meta/recipes-devtools/go/go-1.19/cve-2022-41724.patch | |||
@@ -0,0 +1,2391 @@ | |||
1 | From 00b256e9e3c0fa02a278ec9dfc3e191e02ceaf80 Mon Sep 17 00:00:00 2001 | ||
2 | From: Roland Shoemaker <roland@golang.org> | ||
3 | Date: Wed, 14 Dec 2022 09:43:16 -0800 | ||
4 | Subject: [PATCH] [release-branch.go1.19] crypto/tls: replace all usages of | ||
5 | BytesOrPanic | ||
6 | |||
7 | Message marshalling makes use of BytesOrPanic a lot, under the | ||
8 | assumption that it will never panic. This assumption was incorrect, and | ||
9 | specifically crafted handshakes could trigger panics. Rather than just | ||
10 | surgically replacing the usages of BytesOrPanic in paths that could | ||
11 | panic, replace all usages of it with proper error returns in case there | ||
12 | are other ways of triggering panics which we didn't find. | ||
13 | |||
14 | In one specific case, the tree routed by expandLabel, we replace the | ||
15 | usage of BytesOrPanic, but retain a panic. This function already | ||
16 | explicitly panicked elsewhere, and returning an error from it becomes | ||
17 | rather painful because it requires changing a large number of APIs. | ||
18 | The marshalling is unlikely to ever panic, as the inputs are all either | ||
19 | fixed length, or already limited to the sizes required. If it were to | ||
20 | panic, it'd likely only be during development. A close inspection shows | ||
21 | no paths for a user to cause a panic currently. | ||
22 | |||
23 | This patches ends up being rather large, since it requires routing | ||
24 | errors back through functions which previously had no error returns. | ||
25 | Where possible I've tried to use helpers that reduce the verbosity | ||
26 | of frequently repeated stanzas, and to make the diffs as minimal as | ||
27 | possible. | ||
28 | |||
29 | Thanks to Marten Seemann for reporting this issue. | ||
30 | |||
31 | Updates #58001 | ||
32 | Fixes #58358 | ||
33 | Fixes CVE-2022-41724 | ||
34 | |||
35 | Change-Id: Ieb55867ef0a3e1e867b33f09421932510cb58851 | ||
36 | Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1679436 | ||
37 | Reviewed-by: Julie Qiu <julieqiu@google.com> | ||
38 | TryBot-Result: Security TryBots <security-trybots@go-security-trybots.iam.gserviceaccount.com> | ||
39 | Run-TryBot: Roland Shoemaker <bracewell@google.com> | ||
40 | Reviewed-by: Damien Neil <dneil@google.com> | ||
41 | (cherry picked from commit 0f3a44ad7b41cc89efdfad25278953e17d9c1e04) | ||
42 | Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728204 | ||
43 | Reviewed-by: Tatiana Bradley <tatianabradley@google.com> | ||
44 | Reviewed-on: https://go-review.googlesource.com/c/go/+/468117 | ||
45 | Auto-Submit: Michael Pratt <mpratt@google.com> | ||
46 | Run-TryBot: Michael Pratt <mpratt@google.com> | ||
47 | TryBot-Result: Gopher Robot <gobot@golang.org> | ||
48 | Reviewed-by: Than McIntosh <thanm@google.com> | ||
49 | --- | ||
50 | |||
51 | CVE: CVE-2022-41724 | ||
52 | |||
53 | Upstream-Status: Backport [see text] | ||
54 | |||
55 | https://github.com/golong/go.git commit 00b256e9e3c0fa... | ||
56 | boring_test.go does not exist | ||
57 | modified for conn.go and handshake_messages.go | ||
58 | |||
59 | Signed-off-by: Joe Slater <joe.slater@windriver.com> | ||
60 | |||
61 | --- | ||
62 | src/crypto/tls/boring_test.go | 2 +- | ||
63 | src/crypto/tls/common.go | 2 +- | ||
64 | src/crypto/tls/conn.go | 46 +- | ||
65 | src/crypto/tls/handshake_client.go | 95 +-- | ||
66 | src/crypto/tls/handshake_client_test.go | 4 +- | ||
67 | src/crypto/tls/handshake_client_tls13.go | 74 ++- | ||
68 | src/crypto/tls/handshake_messages.go | 716 +++++++++++----------- | ||
69 | src/crypto/tls/handshake_messages_test.go | 19 +- | ||
70 | src/crypto/tls/handshake_server.go | 73 ++- | ||
71 | src/crypto/tls/handshake_server_test.go | 31 +- | ||
72 | src/crypto/tls/handshake_server_tls13.go | 71 ++- | ||
73 | src/crypto/tls/key_schedule.go | 19 +- | ||
74 | src/crypto/tls/ticket.go | 8 +- | ||
75 | 13 files changed, 657 insertions(+), 503 deletions(-) | ||
76 | |||
77 | --- go.orig/src/crypto/tls/common.go | ||
78 | +++ go/src/crypto/tls/common.go | ||
79 | @@ -1357,7 +1357,7 @@ func (c *Certificate) leaf() (*x509.Cert | ||
80 | } | ||
81 | |||
82 | type handshakeMessage interface { | ||
83 | - marshal() []byte | ||
84 | + marshal() ([]byte, error) | ||
85 | unmarshal([]byte) bool | ||
86 | } | ||
87 | |||
88 | --- go.orig/src/crypto/tls/conn.go | ||
89 | +++ go/src/crypto/tls/conn.go | ||
90 | @@ -994,18 +994,46 @@ func (c *Conn) writeRecordLocked(typ rec | ||
91 | return n, nil | ||
92 | } | ||
93 | |||
94 | -// writeRecord writes a TLS record with the given type and payload to the | ||
95 | -// connection and updates the record layer state. | ||
96 | -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { | ||
97 | +// writeHandshakeRecord writes a handshake message to the connection and updates | ||
98 | +// the record layer state. If transcript is non-nil the marshalled message is | ||
99 | +// written to it. | ||
100 | +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { | ||
101 | c.out.Lock() | ||
102 | defer c.out.Unlock() | ||
103 | |||
104 | - return c.writeRecordLocked(typ, data) | ||
105 | + data, err := msg.marshal() | ||
106 | + if err != nil { | ||
107 | + return 0, err | ||
108 | + } | ||
109 | + if transcript != nil { | ||
110 | + transcript.Write(data) | ||
111 | + } | ||
112 | + | ||
113 | + return c.writeRecordLocked(recordTypeHandshake, data) | ||
114 | +} | ||
115 | + | ||
116 | +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and | ||
117 | +// updates the record layer state. | ||
118 | +func (c *Conn) writeChangeCipherRecord() error { | ||
119 | + c.out.Lock() | ||
120 | + defer c.out.Unlock() | ||
121 | + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) | ||
122 | + return err | ||
123 | } | ||
124 | |||
125 | // readHandshake reads the next handshake message from | ||
126 | -// the record layer. | ||
127 | -func (c *Conn) readHandshake() (interface{}, error) { | ||
128 | +// the record layer. If transcript is non-nil, the message | ||
129 | +// is written to the passed transcriptHash. | ||
130 | + | ||
131 | +// backport 00b256e9e3c0fa02a278ec9dfc3e191e02ceaf80 | ||
132 | +// | ||
133 | +// Commit wants to set this to | ||
134 | +// | ||
135 | +// func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { | ||
136 | +// | ||
137 | +// but that does not compile. Retain the original interface{} argument. | ||
138 | +// | ||
139 | +func (c *Conn) readHandshake(transcript transcriptHash) (interface{}, error) { | ||
140 | for c.hand.Len() < 4 { | ||
141 | if err := c.readRecord(); err != nil { | ||
142 | return nil, err | ||
143 | @@ -1084,6 +1112,11 @@ func (c *Conn) readHandshake() (interfac | ||
144 | if !m.unmarshal(data) { | ||
145 | return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) | ||
146 | } | ||
147 | + | ||
148 | + if transcript != nil { | ||
149 | + transcript.Write(data) | ||
150 | + } | ||
151 | + | ||
152 | return m, nil | ||
153 | } | ||
154 | |||
155 | @@ -1159,7 +1192,7 @@ func (c *Conn) handleRenegotiation() err | ||
156 | return errors.New("tls: internal error: unexpected renegotiation") | ||
157 | } | ||
158 | |||
159 | - msg, err := c.readHandshake() | ||
160 | + msg, err := c.readHandshake(nil) | ||
161 | if err != nil { | ||
162 | return err | ||
163 | } | ||
164 | @@ -1205,7 +1238,7 @@ func (c *Conn) handlePostHandshakeMessag | ||
165 | return c.handleRenegotiation() | ||
166 | } | ||
167 | |||
168 | - msg, err := c.readHandshake() | ||
169 | + msg, err := c.readHandshake(nil) | ||
170 | if err != nil { | ||
171 | return err | ||
172 | } | ||
173 | @@ -1241,7 +1274,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate | ||
174 | defer c.out.Unlock() | ||
175 | |||
176 | msg := &keyUpdateMsg{} | ||
177 | - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) | ||
178 | + msgBytes, err := msg.marshal() | ||
179 | + if err != nil { | ||
180 | + return err | ||
181 | + } | ||
182 | + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) | ||
183 | if err != nil { | ||
184 | // Surface the error at the next write. | ||
185 | c.out.setErrorLocked(err) | ||
186 | --- go.orig/src/crypto/tls/handshake_client.go | ||
187 | +++ go/src/crypto/tls/handshake_client.go | ||
188 | @@ -157,7 +157,10 @@ func (c *Conn) clientHandshake(ctx conte | ||
189 | } | ||
190 | c.serverName = hello.serverName | ||
191 | |||
192 | - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) | ||
193 | + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) | ||
194 | + if err != nil { | ||
195 | + return err | ||
196 | + } | ||
197 | if cacheKey != "" && session != nil { | ||
198 | defer func() { | ||
199 | // If we got a handshake failure when resuming a session, throw away | ||
200 | @@ -172,11 +175,12 @@ func (c *Conn) clientHandshake(ctx conte | ||
201 | }() | ||
202 | } | ||
203 | |||
204 | - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { | ||
205 | + if _, err := c.writeHandshakeRecord(hello, nil); err != nil { | ||
206 | return err | ||
207 | } | ||
208 | |||
209 | - msg, err := c.readHandshake() | ||
210 | + // serverHelloMsg is not included in the transcript | ||
211 | + msg, err := c.readHandshake(nil) | ||
212 | if err != nil { | ||
213 | return err | ||
214 | } | ||
215 | @@ -241,9 +245,9 @@ func (c *Conn) clientHandshake(ctx conte | ||
216 | } | ||
217 | |||
218 | func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, | ||
219 | - session *ClientSessionState, earlySecret, binderKey []byte) { | ||
220 | + session *ClientSessionState, earlySecret, binderKey []byte, err error) { | ||
221 | if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { | ||
222 | - return "", nil, nil, nil | ||
223 | + return "", nil, nil, nil, nil | ||
224 | } | ||
225 | |||
226 | hello.ticketSupported = true | ||
227 | @@ -258,14 +262,14 @@ func (c *Conn) loadSession(hello *client | ||
228 | // renegotiation is primarily used to allow a client to send a client | ||
229 | // certificate, which would be skipped if session resumption occurred. | ||
230 | if c.handshakes != 0 { | ||
231 | - return "", nil, nil, nil | ||
232 | + return "", nil, nil, nil, nil | ||
233 | } | ||
234 | |||
235 | // Try to resume a previously negotiated TLS session, if available. | ||
236 | cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) | ||
237 | session, ok := c.config.ClientSessionCache.Get(cacheKey) | ||
238 | if !ok || session == nil { | ||
239 | - return cacheKey, nil, nil, nil | ||
240 | + return cacheKey, nil, nil, nil, nil | ||
241 | } | ||
242 | |||
243 | // Check that version used for the previous session is still valid. | ||
244 | @@ -277,7 +281,7 @@ func (c *Conn) loadSession(hello *client | ||
245 | } | ||
246 | } | ||
247 | if !versOk { | ||
248 | - return cacheKey, nil, nil, nil | ||
249 | + return cacheKey, nil, nil, nil, nil | ||
250 | } | ||
251 | |||
252 | // Check that the cached server certificate is not expired, and that it's | ||
253 | @@ -286,16 +290,16 @@ func (c *Conn) loadSession(hello *client | ||
254 | if !c.config.InsecureSkipVerify { | ||
255 | if len(session.verifiedChains) == 0 { | ||
256 | // The original connection had InsecureSkipVerify, while this doesn't. | ||
257 | - return cacheKey, nil, nil, nil | ||
258 | + return cacheKey, nil, nil, nil, nil | ||
259 | } | ||
260 | serverCert := session.serverCertificates[0] | ||
261 | if c.config.time().After(serverCert.NotAfter) { | ||
262 | // Expired certificate, delete the entry. | ||
263 | c.config.ClientSessionCache.Put(cacheKey, nil) | ||
264 | - return cacheKey, nil, nil, nil | ||
265 | + return cacheKey, nil, nil, nil, nil | ||
266 | } | ||
267 | if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { | ||
268 | - return cacheKey, nil, nil, nil | ||
269 | + return cacheKey, nil, nil, nil, nil | ||
270 | } | ||
271 | } | ||
272 | |||
273 | @@ -303,7 +307,7 @@ func (c *Conn) loadSession(hello *client | ||
274 | // In TLS 1.2 the cipher suite must match the resumed session. Ensure we | ||
275 | // are still offering it. | ||
276 | if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { | ||
277 | - return cacheKey, nil, nil, nil | ||
278 | + return cacheKey, nil, nil, nil, nil | ||
279 | } | ||
280 | |||
281 | hello.sessionTicket = session.sessionTicket | ||
282 | @@ -313,14 +317,14 @@ func (c *Conn) loadSession(hello *client | ||
283 | // Check that the session ticket is not expired. | ||
284 | if c.config.time().After(session.useBy) { | ||
285 | c.config.ClientSessionCache.Put(cacheKey, nil) | ||
286 | - return cacheKey, nil, nil, nil | ||
287 | + return cacheKey, nil, nil, nil, nil | ||
288 | } | ||
289 | |||
290 | // In TLS 1.3 the KDF hash must match the resumed session. Ensure we | ||
291 | // offer at least one cipher suite with that hash. | ||
292 | cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) | ||
293 | if cipherSuite == nil { | ||
294 | - return cacheKey, nil, nil, nil | ||
295 | + return cacheKey, nil, nil, nil, nil | ||
296 | } | ||
297 | cipherSuiteOk := false | ||
298 | for _, offeredID := range hello.cipherSuites { | ||
299 | @@ -331,7 +335,7 @@ func (c *Conn) loadSession(hello *client | ||
300 | } | ||
301 | } | ||
302 | if !cipherSuiteOk { | ||
303 | - return cacheKey, nil, nil, nil | ||
304 | + return cacheKey, nil, nil, nil, nil | ||
305 | } | ||
306 | |||
307 | // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. | ||
308 | @@ -349,9 +353,15 @@ func (c *Conn) loadSession(hello *client | ||
309 | earlySecret = cipherSuite.extract(psk, nil) | ||
310 | binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) | ||
311 | transcript := cipherSuite.hash.New() | ||
312 | - transcript.Write(hello.marshalWithoutBinders()) | ||
313 | + helloBytes, err := hello.marshalWithoutBinders() | ||
314 | + if err != nil { | ||
315 | + return "", nil, nil, nil, err | ||
316 | + } | ||
317 | + transcript.Write(helloBytes) | ||
318 | pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} | ||
319 | - hello.updateBinders(pskBinders) | ||
320 | + if err := hello.updateBinders(pskBinders); err != nil { | ||
321 | + return "", nil, nil, nil, err | ||
322 | + } | ||
323 | |||
324 | return | ||
325 | } | ||
326 | @@ -396,8 +406,12 @@ func (hs *clientHandshakeState) handshak | ||
327 | hs.finishedHash.discardHandshakeBuffer() | ||
328 | } | ||
329 | |||
330 | - hs.finishedHash.Write(hs.hello.marshal()) | ||
331 | - hs.finishedHash.Write(hs.serverHello.marshal()) | ||
332 | + if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { | ||
333 | + return err | ||
334 | + } | ||
335 | + if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { | ||
336 | + return err | ||
337 | + } | ||
338 | |||
339 | c.buffering = true | ||
340 | c.didResume = isResume | ||
341 | @@ -468,7 +482,7 @@ func (hs *clientHandshakeState) pickCiph | ||
342 | func (hs *clientHandshakeState) doFullHandshake() error { | ||
343 | c := hs.c | ||
344 | |||
345 | - msg, err := c.readHandshake() | ||
346 | + msg, err := c.readHandshake(&hs.finishedHash) | ||
347 | if err != nil { | ||
348 | return err | ||
349 | } | ||
350 | @@ -477,9 +491,8 @@ func (hs *clientHandshakeState) doFullHa | ||
351 | c.sendAlert(alertUnexpectedMessage) | ||
352 | return unexpectedMessageError(certMsg, msg) | ||
353 | } | ||
354 | - hs.finishedHash.Write(certMsg.marshal()) | ||
355 | |||
356 | - msg, err = c.readHandshake() | ||
357 | + msg, err = c.readHandshake(&hs.finishedHash) | ||
358 | if err != nil { | ||
359 | return err | ||
360 | } | ||
361 | @@ -497,11 +510,10 @@ func (hs *clientHandshakeState) doFullHa | ||
362 | c.sendAlert(alertUnexpectedMessage) | ||
363 | return errors.New("tls: received unexpected CertificateStatus message") | ||
364 | } | ||
365 | - hs.finishedHash.Write(cs.marshal()) | ||
366 | |||
367 | c.ocspResponse = cs.response | ||
368 | |||
369 | - msg, err = c.readHandshake() | ||
370 | + msg, err = c.readHandshake(&hs.finishedHash) | ||
371 | if err != nil { | ||
372 | return err | ||
373 | } | ||
374 | @@ -530,14 +542,13 @@ func (hs *clientHandshakeState) doFullHa | ||
375 | |||
376 | skx, ok := msg.(*serverKeyExchangeMsg) | ||
377 | if ok { | ||
378 | - hs.finishedHash.Write(skx.marshal()) | ||
379 | err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) | ||
380 | if err != nil { | ||
381 | c.sendAlert(alertUnexpectedMessage) | ||
382 | return err | ||
383 | } | ||
384 | |||
385 | - msg, err = c.readHandshake() | ||
386 | + msg, err = c.readHandshake(&hs.finishedHash) | ||
387 | if err != nil { | ||
388 | return err | ||
389 | } | ||
390 | @@ -548,7 +559,6 @@ func (hs *clientHandshakeState) doFullHa | ||
391 | certReq, ok := msg.(*certificateRequestMsg) | ||
392 | if ok { | ||
393 | certRequested = true | ||
394 | - hs.finishedHash.Write(certReq.marshal()) | ||
395 | |||
396 | cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) | ||
397 | if chainToSend, err = c.getClientCertificate(cri); err != nil { | ||
398 | @@ -556,7 +566,7 @@ func (hs *clientHandshakeState) doFullHa | ||
399 | return err | ||
400 | } | ||
401 | |||
402 | - msg, err = c.readHandshake() | ||
403 | + msg, err = c.readHandshake(&hs.finishedHash) | ||
404 | if err != nil { | ||
405 | return err | ||
406 | } | ||
407 | @@ -567,7 +577,6 @@ func (hs *clientHandshakeState) doFullHa | ||
408 | c.sendAlert(alertUnexpectedMessage) | ||
409 | return unexpectedMessageError(shd, msg) | ||
410 | } | ||
411 | - hs.finishedHash.Write(shd.marshal()) | ||
412 | |||
413 | // If the server requested a certificate then we have to send a | ||
414 | // Certificate message, even if it's empty because we don't have a | ||
415 | @@ -575,8 +584,7 @@ func (hs *clientHandshakeState) doFullHa | ||
416 | if certRequested { | ||
417 | certMsg = new(certificateMsg) | ||
418 | certMsg.certificates = chainToSend.Certificate | ||
419 | - hs.finishedHash.Write(certMsg.marshal()) | ||
420 | - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | ||
421 | + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { | ||
422 | return err | ||
423 | } | ||
424 | } | ||
425 | @@ -587,8 +595,7 @@ func (hs *clientHandshakeState) doFullHa | ||
426 | return err | ||
427 | } | ||
428 | if ckx != nil { | ||
429 | - hs.finishedHash.Write(ckx.marshal()) | ||
430 | - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { | ||
431 | + if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { | ||
432 | return err | ||
433 | } | ||
434 | } | ||
435 | @@ -635,8 +642,7 @@ func (hs *clientHandshakeState) doFullHa | ||
436 | return err | ||
437 | } | ||
438 | |||
439 | - hs.finishedHash.Write(certVerify.marshal()) | ||
440 | - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { | ||
441 | + if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { | ||
442 | return err | ||
443 | } | ||
444 | } | ||
445 | @@ -771,7 +777,10 @@ func (hs *clientHandshakeState) readFini | ||
446 | return err | ||
447 | } | ||
448 | |||
449 | - msg, err := c.readHandshake() | ||
450 | + // finishedMsg is included in the transcript, but not until after we | ||
451 | + // check the client version, since the state before this message was | ||
452 | + // sent is used during verification. | ||
453 | + msg, err := c.readHandshake(nil) | ||
454 | if err != nil { | ||
455 | return err | ||
456 | } | ||
457 | @@ -787,7 +796,11 @@ func (hs *clientHandshakeState) readFini | ||
458 | c.sendAlert(alertHandshakeFailure) | ||
459 | return errors.New("tls: server's Finished message was incorrect") | ||
460 | } | ||
461 | - hs.finishedHash.Write(serverFinished.marshal()) | ||
462 | + | ||
463 | + if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { | ||
464 | + return err | ||
465 | + } | ||
466 | + | ||
467 | copy(out, verify) | ||
468 | return nil | ||
469 | } | ||
470 | @@ -798,7 +811,7 @@ func (hs *clientHandshakeState) readSess | ||
471 | } | ||
472 | |||
473 | c := hs.c | ||
474 | - msg, err := c.readHandshake() | ||
475 | + msg, err := c.readHandshake(&hs.finishedHash) | ||
476 | if err != nil { | ||
477 | return err | ||
478 | } | ||
479 | @@ -807,7 +820,6 @@ func (hs *clientHandshakeState) readSess | ||
480 | c.sendAlert(alertUnexpectedMessage) | ||
481 | return unexpectedMessageError(sessionTicketMsg, msg) | ||
482 | } | ||
483 | - hs.finishedHash.Write(sessionTicketMsg.marshal()) | ||
484 | |||
485 | hs.session = &ClientSessionState{ | ||
486 | sessionTicket: sessionTicketMsg.ticket, | ||
487 | @@ -827,14 +839,13 @@ func (hs *clientHandshakeState) readSess | ||
488 | func (hs *clientHandshakeState) sendFinished(out []byte) error { | ||
489 | c := hs.c | ||
490 | |||
491 | - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { | ||
492 | + if err := c.writeChangeCipherRecord(); err != nil { | ||
493 | return err | ||
494 | } | ||
495 | |||
496 | finished := new(finishedMsg) | ||
497 | finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) | ||
498 | - hs.finishedHash.Write(finished.marshal()) | ||
499 | - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | ||
500 | + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { | ||
501 | return err | ||
502 | } | ||
503 | copy(out, finished.verifyData) | ||
504 | --- go.orig/src/crypto/tls/handshake_client_test.go | ||
505 | +++ go/src/crypto/tls/handshake_client_test.go | ||
506 | @@ -1257,7 +1257,7 @@ func TestServerSelectingUnconfiguredAppl | ||
507 | cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256, | ||
508 | alpnProtocol: "how-about-this", | ||
509 | } | ||
510 | - serverHelloBytes := serverHello.marshal() | ||
511 | + serverHelloBytes := mustMarshal(t, serverHello) | ||
512 | |||
513 | s.Write([]byte{ | ||
514 | byte(recordTypeHandshake), | ||
515 | @@ -1500,7 +1500,7 @@ func TestServerSelectingUnconfiguredCiph | ||
516 | random: make([]byte, 32), | ||
517 | cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, | ||
518 | } | ||
519 | - serverHelloBytes := serverHello.marshal() | ||
520 | + serverHelloBytes := mustMarshal(t, serverHello) | ||
521 | |||
522 | s.Write([]byte{ | ||
523 | byte(recordTypeHandshake), | ||
524 | --- go.orig/src/crypto/tls/handshake_client_tls13.go | ||
525 | +++ go/src/crypto/tls/handshake_client_tls13.go | ||
526 | @@ -58,7 +58,10 @@ func (hs *clientHandshakeStateTLS13) han | ||
527 | } | ||
528 | |||
529 | hs.transcript = hs.suite.hash.New() | ||
530 | - hs.transcript.Write(hs.hello.marshal()) | ||
531 | + | ||
532 | + if err := transcriptMsg(hs.hello, hs.transcript); err != nil { | ||
533 | + return err | ||
534 | + } | ||
535 | |||
536 | if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { | ||
537 | if err := hs.sendDummyChangeCipherSpec(); err != nil { | ||
538 | @@ -69,7 +72,9 @@ func (hs *clientHandshakeStateTLS13) han | ||
539 | } | ||
540 | } | ||
541 | |||
542 | - hs.transcript.Write(hs.serverHello.marshal()) | ||
543 | + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { | ||
544 | + return err | ||
545 | + } | ||
546 | |||
547 | c.buffering = true | ||
548 | if err := hs.processServerHello(); err != nil { | ||
549 | @@ -168,8 +173,7 @@ func (hs *clientHandshakeStateTLS13) sen | ||
550 | } | ||
551 | hs.sentDummyCCS = true | ||
552 | |||
553 | - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) | ||
554 | - return err | ||
555 | + return hs.c.writeChangeCipherRecord() | ||
556 | } | ||
557 | |||
558 | // processHelloRetryRequest handles the HRR in hs.serverHello, modifies and | ||
559 | @@ -184,7 +188,9 @@ func (hs *clientHandshakeStateTLS13) pro | ||
560 | hs.transcript.Reset() | ||
561 | hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) | ||
562 | hs.transcript.Write(chHash) | ||
563 | - hs.transcript.Write(hs.serverHello.marshal()) | ||
564 | + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { | ||
565 | + return err | ||
566 | + } | ||
567 | |||
568 | // The only HelloRetryRequest extensions we support are key_share and | ||
569 | // cookie, and clients must abort the handshake if the HRR would not result | ||
570 | @@ -249,10 +255,18 @@ func (hs *clientHandshakeStateTLS13) pro | ||
571 | transcript := hs.suite.hash.New() | ||
572 | transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) | ||
573 | transcript.Write(chHash) | ||
574 | - transcript.Write(hs.serverHello.marshal()) | ||
575 | - transcript.Write(hs.hello.marshalWithoutBinders()) | ||
576 | + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { | ||
577 | + return err | ||
578 | + } | ||
579 | + helloBytes, err := hs.hello.marshalWithoutBinders() | ||
580 | + if err != nil { | ||
581 | + return err | ||
582 | + } | ||
583 | + transcript.Write(helloBytes) | ||
584 | pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} | ||
585 | - hs.hello.updateBinders(pskBinders) | ||
586 | + if err := hs.hello.updateBinders(pskBinders); err != nil { | ||
587 | + return err | ||
588 | + } | ||
589 | } else { | ||
590 | // Server selected a cipher suite incompatible with the PSK. | ||
591 | hs.hello.pskIdentities = nil | ||
592 | @@ -260,12 +274,12 @@ func (hs *clientHandshakeStateTLS13) pro | ||
593 | } | ||
594 | } | ||
595 | |||
596 | - hs.transcript.Write(hs.hello.marshal()) | ||
597 | - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | ||
598 | + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { | ||
599 | return err | ||
600 | } | ||
601 | |||
602 | - msg, err := c.readHandshake() | ||
603 | + // serverHelloMsg is not included in the transcript | ||
604 | + msg, err := c.readHandshake(nil) | ||
605 | if err != nil { | ||
606 | return err | ||
607 | } | ||
608 | @@ -354,6 +368,7 @@ func (hs *clientHandshakeStateTLS13) est | ||
609 | if !hs.usingPSK { | ||
610 | earlySecret = hs.suite.extract(nil, nil) | ||
611 | } | ||
612 | + | ||
613 | handshakeSecret := hs.suite.extract(sharedKey, | ||
614 | hs.suite.deriveSecret(earlySecret, "derived", nil)) | ||
615 | |||
616 | @@ -384,7 +399,7 @@ func (hs *clientHandshakeStateTLS13) est | ||
617 | func (hs *clientHandshakeStateTLS13) readServerParameters() error { | ||
618 | c := hs.c | ||
619 | |||
620 | - msg, err := c.readHandshake() | ||
621 | + msg, err := c.readHandshake(hs.transcript) | ||
622 | if err != nil { | ||
623 | return err | ||
624 | } | ||
625 | @@ -394,7 +409,6 @@ func (hs *clientHandshakeStateTLS13) rea | ||
626 | c.sendAlert(alertUnexpectedMessage) | ||
627 | return unexpectedMessageError(encryptedExtensions, msg) | ||
628 | } | ||
629 | - hs.transcript.Write(encryptedExtensions.marshal()) | ||
630 | |||
631 | if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { | ||
632 | c.sendAlert(alertUnsupportedExtension) | ||
633 | @@ -423,18 +437,16 @@ func (hs *clientHandshakeStateTLS13) rea | ||
634 | return nil | ||
635 | } | ||
636 | |||
637 | - msg, err := c.readHandshake() | ||
638 | + msg, err := c.readHandshake(hs.transcript) | ||
639 | if err != nil { | ||
640 | return err | ||
641 | } | ||
642 | |||
643 | certReq, ok := msg.(*certificateRequestMsgTLS13) | ||
644 | if ok { | ||
645 | - hs.transcript.Write(certReq.marshal()) | ||
646 | - | ||
647 | hs.certReq = certReq | ||
648 | |||
649 | - msg, err = c.readHandshake() | ||
650 | + msg, err = c.readHandshake(hs.transcript) | ||
651 | if err != nil { | ||
652 | return err | ||
653 | } | ||
654 | @@ -449,7 +461,6 @@ func (hs *clientHandshakeStateTLS13) rea | ||
655 | c.sendAlert(alertDecodeError) | ||
656 | return errors.New("tls: received empty certificates message") | ||
657 | } | ||
658 | - hs.transcript.Write(certMsg.marshal()) | ||
659 | |||
660 | c.scts = certMsg.certificate.SignedCertificateTimestamps | ||
661 | c.ocspResponse = certMsg.certificate.OCSPStaple | ||
662 | @@ -458,7 +469,10 @@ func (hs *clientHandshakeStateTLS13) rea | ||
663 | return err | ||
664 | } | ||
665 | |||
666 | - msg, err = c.readHandshake() | ||
667 | + // certificateVerifyMsg is included in the transcript, but not until | ||
668 | + // after we verify the handshake signature, since the state before | ||
669 | + // this message was sent is used. | ||
670 | + msg, err = c.readHandshake(nil) | ||
671 | if err != nil { | ||
672 | return err | ||
673 | } | ||
674 | @@ -489,7 +503,9 @@ func (hs *clientHandshakeStateTLS13) rea | ||
675 | return errors.New("tls: invalid signature by the server certificate: " + err.Error()) | ||
676 | } | ||
677 | |||
678 | - hs.transcript.Write(certVerify.marshal()) | ||
679 | + if err := transcriptMsg(certVerify, hs.transcript); err != nil { | ||
680 | + return err | ||
681 | + } | ||
682 | |||
683 | return nil | ||
684 | } | ||
685 | @@ -497,7 +513,10 @@ func (hs *clientHandshakeStateTLS13) rea | ||
686 | func (hs *clientHandshakeStateTLS13) readServerFinished() error { | ||
687 | c := hs.c | ||
688 | |||
689 | - msg, err := c.readHandshake() | ||
690 | + // finishedMsg is included in the transcript, but not until after we | ||
691 | + // check the client version, since the state before this message was | ||
692 | + // sent is used during verification. | ||
693 | + msg, err := c.readHandshake(nil) | ||
694 | if err != nil { | ||
695 | return err | ||
696 | } | ||
697 | @@ -514,7 +533,9 @@ func (hs *clientHandshakeStateTLS13) rea | ||
698 | return errors.New("tls: invalid server finished hash") | ||
699 | } | ||
700 | |||
701 | - hs.transcript.Write(finished.marshal()) | ||
702 | + if err := transcriptMsg(finished, hs.transcript); err != nil { | ||
703 | + return err | ||
704 | + } | ||
705 | |||
706 | // Derive secrets that take context through the server Finished. | ||
707 | |||
708 | @@ -563,8 +584,7 @@ func (hs *clientHandshakeStateTLS13) sen | ||
709 | certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 | ||
710 | certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 | ||
711 | |||
712 | - hs.transcript.Write(certMsg.marshal()) | ||
713 | - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | ||
714 | + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { | ||
715 | return err | ||
716 | } | ||
717 | |||
718 | @@ -601,8 +621,7 @@ func (hs *clientHandshakeStateTLS13) sen | ||
719 | } | ||
720 | certVerifyMsg.signature = sig | ||
721 | |||
722 | - hs.transcript.Write(certVerifyMsg.marshal()) | ||
723 | - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { | ||
724 | + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { | ||
725 | return err | ||
726 | } | ||
727 | |||
728 | @@ -616,8 +635,7 @@ func (hs *clientHandshakeStateTLS13) sen | ||
729 | verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), | ||
730 | } | ||
731 | |||
732 | - hs.transcript.Write(finished.marshal()) | ||
733 | - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | ||
734 | + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { | ||
735 | return err | ||
736 | } | ||
737 | |||
738 | --- go.orig/src/crypto/tls/handshake_messages.go | ||
739 | +++ go/src/crypto/tls/handshake_messages.go | ||
740 | @@ -5,6 +5,7 @@ | ||
741 | package tls | ||
742 | |||
743 | import ( | ||
744 | + "errors" | ||
745 | "fmt" | ||
746 | "strings" | ||
747 | |||
748 | @@ -94,9 +95,181 @@ type clientHelloMsg struct { | ||
749 | pskBinders [][]byte | ||
750 | } | ||
751 | |||
752 | -func (m *clientHelloMsg) marshal() []byte { | ||
753 | +func (m *clientHelloMsg) marshal() ([]byte, error) { | ||
754 | if m.raw != nil { | ||
755 | - return m.raw | ||
756 | + return m.raw, nil | ||
757 | + } | ||
758 | + | ||
759 | + var exts cryptobyte.Builder | ||
760 | + if len(m.serverName) > 0 { | ||
761 | + // RFC 6066, Section 3 | ||
762 | + exts.AddUint16(extensionServerName) | ||
763 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
764 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
765 | + exts.AddUint8(0) // name_type = host_name | ||
766 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
767 | + exts.AddBytes([]byte(m.serverName)) | ||
768 | + }) | ||
769 | + }) | ||
770 | + }) | ||
771 | + } | ||
772 | + if m.ocspStapling { | ||
773 | + // RFC 4366, Section 3.6 | ||
774 | + exts.AddUint16(extensionStatusRequest) | ||
775 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
776 | + exts.AddUint8(1) // status_type = ocsp | ||
777 | + exts.AddUint16(0) // empty responder_id_list | ||
778 | + exts.AddUint16(0) // empty request_extensions | ||
779 | + }) | ||
780 | + } | ||
781 | + if len(m.supportedCurves) > 0 { | ||
782 | + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 | ||
783 | + exts.AddUint16(extensionSupportedCurves) | ||
784 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
785 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
786 | + for _, curve := range m.supportedCurves { | ||
787 | + exts.AddUint16(uint16(curve)) | ||
788 | + } | ||
789 | + }) | ||
790 | + }) | ||
791 | + } | ||
792 | + if len(m.supportedPoints) > 0 { | ||
793 | + // RFC 4492, Section 5.1.2 | ||
794 | + exts.AddUint16(extensionSupportedPoints) | ||
795 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
796 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
797 | + exts.AddBytes(m.supportedPoints) | ||
798 | + }) | ||
799 | + }) | ||
800 | + } | ||
801 | + if m.ticketSupported { | ||
802 | + // RFC 5077, Section 3.2 | ||
803 | + exts.AddUint16(extensionSessionTicket) | ||
804 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
805 | + exts.AddBytes(m.sessionTicket) | ||
806 | + }) | ||
807 | + } | ||
808 | + if len(m.supportedSignatureAlgorithms) > 0 { | ||
809 | + // RFC 5246, Section 7.4.1.4.1 | ||
810 | + exts.AddUint16(extensionSignatureAlgorithms) | ||
811 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
812 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
813 | + for _, sigAlgo := range m.supportedSignatureAlgorithms { | ||
814 | + exts.AddUint16(uint16(sigAlgo)) | ||
815 | + } | ||
816 | + }) | ||
817 | + }) | ||
818 | + } | ||
819 | + if len(m.supportedSignatureAlgorithmsCert) > 0 { | ||
820 | + // RFC 8446, Section 4.2.3 | ||
821 | + exts.AddUint16(extensionSignatureAlgorithmsCert) | ||
822 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
823 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
824 | + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { | ||
825 | + exts.AddUint16(uint16(sigAlgo)) | ||
826 | + } | ||
827 | + }) | ||
828 | + }) | ||
829 | + } | ||
830 | + if m.secureRenegotiationSupported { | ||
831 | + // RFC 5746, Section 3.2 | ||
832 | + exts.AddUint16(extensionRenegotiationInfo) | ||
833 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
834 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
835 | + exts.AddBytes(m.secureRenegotiation) | ||
836 | + }) | ||
837 | + }) | ||
838 | + } | ||
839 | + if len(m.alpnProtocols) > 0 { | ||
840 | + // RFC 7301, Section 3.1 | ||
841 | + exts.AddUint16(extensionALPN) | ||
842 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
843 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
844 | + for _, proto := range m.alpnProtocols { | ||
845 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
846 | + exts.AddBytes([]byte(proto)) | ||
847 | + }) | ||
848 | + } | ||
849 | + }) | ||
850 | + }) | ||
851 | + } | ||
852 | + if m.scts { | ||
853 | + // RFC 6962, Section 3.3.1 | ||
854 | + exts.AddUint16(extensionSCT) | ||
855 | + exts.AddUint16(0) // empty extension_data | ||
856 | + } | ||
857 | + if len(m.supportedVersions) > 0 { | ||
858 | + // RFC 8446, Section 4.2.1 | ||
859 | + exts.AddUint16(extensionSupportedVersions) | ||
860 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
861 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
862 | + for _, vers := range m.supportedVersions { | ||
863 | + exts.AddUint16(vers) | ||
864 | + } | ||
865 | + }) | ||
866 | + }) | ||
867 | + } | ||
868 | + if len(m.cookie) > 0 { | ||
869 | + // RFC 8446, Section 4.2.2 | ||
870 | + exts.AddUint16(extensionCookie) | ||
871 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
872 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
873 | + exts.AddBytes(m.cookie) | ||
874 | + }) | ||
875 | + }) | ||
876 | + } | ||
877 | + if len(m.keyShares) > 0 { | ||
878 | + // RFC 8446, Section 4.2.8 | ||
879 | + exts.AddUint16(extensionKeyShare) | ||
880 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
881 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
882 | + for _, ks := range m.keyShares { | ||
883 | + exts.AddUint16(uint16(ks.group)) | ||
884 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
885 | + exts.AddBytes(ks.data) | ||
886 | + }) | ||
887 | + } | ||
888 | + }) | ||
889 | + }) | ||
890 | + } | ||
891 | + if m.earlyData { | ||
892 | + // RFC 8446, Section 4.2.10 | ||
893 | + exts.AddUint16(extensionEarlyData) | ||
894 | + exts.AddUint16(0) // empty extension_data | ||
895 | + } | ||
896 | + if len(m.pskModes) > 0 { | ||
897 | + // RFC 8446, Section 4.2.9 | ||
898 | + exts.AddUint16(extensionPSKModes) | ||
899 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
900 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
901 | + exts.AddBytes(m.pskModes) | ||
902 | + }) | ||
903 | + }) | ||
904 | + } | ||
905 | + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension | ||
906 | + // RFC 8446, Section 4.2.11 | ||
907 | + exts.AddUint16(extensionPreSharedKey) | ||
908 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
909 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
910 | + for _, psk := range m.pskIdentities { | ||
911 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
912 | + exts.AddBytes(psk.label) | ||
913 | + }) | ||
914 | + exts.AddUint32(psk.obfuscatedTicketAge) | ||
915 | + } | ||
916 | + }) | ||
917 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
918 | + for _, binder := range m.pskBinders { | ||
919 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
920 | + exts.AddBytes(binder) | ||
921 | + }) | ||
922 | + } | ||
923 | + }) | ||
924 | + }) | ||
925 | + } | ||
926 | + extBytes, err := exts.Bytes() | ||
927 | + if err != nil { | ||
928 | + return nil, err | ||
929 | } | ||
930 | |||
931 | var b cryptobyte.Builder | ||
932 | @@ -116,219 +289,53 @@ func (m *clientHelloMsg) marshal() []byt | ||
933 | b.AddBytes(m.compressionMethods) | ||
934 | }) | ||
935 | |||
936 | - // If extensions aren't present, omit them. | ||
937 | - var extensionsPresent bool | ||
938 | - bWithoutExtensions := *b | ||
939 | - | ||
940 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
941 | - if len(m.serverName) > 0 { | ||
942 | - // RFC 6066, Section 3 | ||
943 | - b.AddUint16(extensionServerName) | ||
944 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
945 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
946 | - b.AddUint8(0) // name_type = host_name | ||
947 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
948 | - b.AddBytes([]byte(m.serverName)) | ||
949 | - }) | ||
950 | - }) | ||
951 | - }) | ||
952 | - } | ||
953 | - if m.ocspStapling { | ||
954 | - // RFC 4366, Section 3.6 | ||
955 | - b.AddUint16(extensionStatusRequest) | ||
956 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
957 | - b.AddUint8(1) // status_type = ocsp | ||
958 | - b.AddUint16(0) // empty responder_id_list | ||
959 | - b.AddUint16(0) // empty request_extensions | ||
960 | - }) | ||
961 | - } | ||
962 | - if len(m.supportedCurves) > 0 { | ||
963 | - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 | ||
964 | - b.AddUint16(extensionSupportedCurves) | ||
965 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
966 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
967 | - for _, curve := range m.supportedCurves { | ||
968 | - b.AddUint16(uint16(curve)) | ||
969 | - } | ||
970 | - }) | ||
971 | - }) | ||
972 | - } | ||
973 | - if len(m.supportedPoints) > 0 { | ||
974 | - // RFC 4492, Section 5.1.2 | ||
975 | - b.AddUint16(extensionSupportedPoints) | ||
976 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
977 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
978 | - b.AddBytes(m.supportedPoints) | ||
979 | - }) | ||
980 | - }) | ||
981 | - } | ||
982 | - if m.ticketSupported { | ||
983 | - // RFC 5077, Section 3.2 | ||
984 | - b.AddUint16(extensionSessionTicket) | ||
985 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
986 | - b.AddBytes(m.sessionTicket) | ||
987 | - }) | ||
988 | - } | ||
989 | - if len(m.supportedSignatureAlgorithms) > 0 { | ||
990 | - // RFC 5246, Section 7.4.1.4.1 | ||
991 | - b.AddUint16(extensionSignatureAlgorithms) | ||
992 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
993 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
994 | - for _, sigAlgo := range m.supportedSignatureAlgorithms { | ||
995 | - b.AddUint16(uint16(sigAlgo)) | ||
996 | - } | ||
997 | - }) | ||
998 | - }) | ||
999 | - } | ||
1000 | - if len(m.supportedSignatureAlgorithmsCert) > 0 { | ||
1001 | - // RFC 8446, Section 4.2.3 | ||
1002 | - b.AddUint16(extensionSignatureAlgorithmsCert) | ||
1003 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1004 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1005 | - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { | ||
1006 | - b.AddUint16(uint16(sigAlgo)) | ||
1007 | - } | ||
1008 | - }) | ||
1009 | - }) | ||
1010 | - } | ||
1011 | - if m.secureRenegotiationSupported { | ||
1012 | - // RFC 5746, Section 3.2 | ||
1013 | - b.AddUint16(extensionRenegotiationInfo) | ||
1014 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1015 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1016 | - b.AddBytes(m.secureRenegotiation) | ||
1017 | - }) | ||
1018 | - }) | ||
1019 | - } | ||
1020 | - if len(m.alpnProtocols) > 0 { | ||
1021 | - // RFC 7301, Section 3.1 | ||
1022 | - b.AddUint16(extensionALPN) | ||
1023 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1024 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1025 | - for _, proto := range m.alpnProtocols { | ||
1026 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1027 | - b.AddBytes([]byte(proto)) | ||
1028 | - }) | ||
1029 | - } | ||
1030 | - }) | ||
1031 | - }) | ||
1032 | - } | ||
1033 | - if m.scts { | ||
1034 | - // RFC 6962, Section 3.3.1 | ||
1035 | - b.AddUint16(extensionSCT) | ||
1036 | - b.AddUint16(0) // empty extension_data | ||
1037 | - } | ||
1038 | - if len(m.supportedVersions) > 0 { | ||
1039 | - // RFC 8446, Section 4.2.1 | ||
1040 | - b.AddUint16(extensionSupportedVersions) | ||
1041 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1042 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1043 | - for _, vers := range m.supportedVersions { | ||
1044 | - b.AddUint16(vers) | ||
1045 | - } | ||
1046 | - }) | ||
1047 | - }) | ||
1048 | - } | ||
1049 | - if len(m.cookie) > 0 { | ||
1050 | - // RFC 8446, Section 4.2.2 | ||
1051 | - b.AddUint16(extensionCookie) | ||
1052 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1053 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1054 | - b.AddBytes(m.cookie) | ||
1055 | - }) | ||
1056 | - }) | ||
1057 | - } | ||
1058 | - if len(m.keyShares) > 0 { | ||
1059 | - // RFC 8446, Section 4.2.8 | ||
1060 | - b.AddUint16(extensionKeyShare) | ||
1061 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1062 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1063 | - for _, ks := range m.keyShares { | ||
1064 | - b.AddUint16(uint16(ks.group)) | ||
1065 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1066 | - b.AddBytes(ks.data) | ||
1067 | - }) | ||
1068 | - } | ||
1069 | - }) | ||
1070 | - }) | ||
1071 | - } | ||
1072 | - if m.earlyData { | ||
1073 | - // RFC 8446, Section 4.2.10 | ||
1074 | - b.AddUint16(extensionEarlyData) | ||
1075 | - b.AddUint16(0) // empty extension_data | ||
1076 | - } | ||
1077 | - if len(m.pskModes) > 0 { | ||
1078 | - // RFC 8446, Section 4.2.9 | ||
1079 | - b.AddUint16(extensionPSKModes) | ||
1080 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1081 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1082 | - b.AddBytes(m.pskModes) | ||
1083 | - }) | ||
1084 | - }) | ||
1085 | - } | ||
1086 | - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension | ||
1087 | - // RFC 8446, Section 4.2.11 | ||
1088 | - b.AddUint16(extensionPreSharedKey) | ||
1089 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1090 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1091 | - for _, psk := range m.pskIdentities { | ||
1092 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1093 | - b.AddBytes(psk.label) | ||
1094 | - }) | ||
1095 | - b.AddUint32(psk.obfuscatedTicketAge) | ||
1096 | - } | ||
1097 | - }) | ||
1098 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1099 | - for _, binder := range m.pskBinders { | ||
1100 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1101 | - b.AddBytes(binder) | ||
1102 | - }) | ||
1103 | - } | ||
1104 | - }) | ||
1105 | - }) | ||
1106 | - } | ||
1107 | - | ||
1108 | - extensionsPresent = len(b.BytesOrPanic()) > 2 | ||
1109 | - }) | ||
1110 | - | ||
1111 | - if !extensionsPresent { | ||
1112 | - *b = bWithoutExtensions | ||
1113 | - } | ||
1114 | - }) | ||
1115 | + if len(extBytes) > 0 { | ||
1116 | + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1117 | + b.AddBytes(extBytes) | ||
1118 | + }) | ||
1119 | + } | ||
1120 | + }) | ||
1121 | |||
1122 | - m.raw = b.BytesOrPanic() | ||
1123 | - return m.raw | ||
1124 | + m.raw, err = b.Bytes() | ||
1125 | + return m.raw, err | ||
1126 | } | ||
1127 | |||
1128 | // marshalWithoutBinders returns the ClientHello through the | ||
1129 | // PreSharedKeyExtension.identities field, according to RFC 8446, Section | ||
1130 | // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. | ||
1131 | -func (m *clientHelloMsg) marshalWithoutBinders() []byte { | ||
1132 | +func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { | ||
1133 | bindersLen := 2 // uint16 length prefix | ||
1134 | for _, binder := range m.pskBinders { | ||
1135 | bindersLen += 1 // uint8 length prefix | ||
1136 | bindersLen += len(binder) | ||
1137 | } | ||
1138 | |||
1139 | - fullMessage := m.marshal() | ||
1140 | - return fullMessage[:len(fullMessage)-bindersLen] | ||
1141 | + fullMessage, err := m.marshal() | ||
1142 | + if err != nil { | ||
1143 | + return nil, err | ||
1144 | + } | ||
1145 | + return fullMessage[:len(fullMessage)-bindersLen], nil | ||
1146 | } | ||
1147 | |||
1148 | // updateBinders updates the m.pskBinders field, if necessary updating the | ||
1149 | // cached marshaled representation. The supplied binders must have the same | ||
1150 | // length as the current m.pskBinders. | ||
1151 | -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { | ||
1152 | +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { | ||
1153 | if len(pskBinders) != len(m.pskBinders) { | ||
1154 | - panic("tls: internal error: pskBinders length mismatch") | ||
1155 | + return errors.New("tls: internal error: pskBinders length mismatch") | ||
1156 | } | ||
1157 | for i := range m.pskBinders { | ||
1158 | if len(pskBinders[i]) != len(m.pskBinders[i]) { | ||
1159 | - panic("tls: internal error: pskBinders length mismatch") | ||
1160 | + return errors.New("tls: internal error: pskBinders length mismatch") | ||
1161 | } | ||
1162 | } | ||
1163 | m.pskBinders = pskBinders | ||
1164 | if m.raw != nil { | ||
1165 | - lenWithoutBinders := len(m.marshalWithoutBinders()) | ||
1166 | + helloBytes, err := m.marshalWithoutBinders() | ||
1167 | + if err != nil { | ||
1168 | + return err | ||
1169 | + } | ||
1170 | + lenWithoutBinders := len(helloBytes) | ||
1171 | // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. | ||
1172 | b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) | ||
1173 | b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1174 | @@ -339,9 +346,11 @@ func (m *clientHelloMsg) updateBinders(p | ||
1175 | } | ||
1176 | }) | ||
1177 | if len(b.BytesOrPanic()) != len(m.raw) { | ||
1178 | - panic("tls: internal error: failed to update binders") | ||
1179 | + return errors.New("tls: internal error: failed to update binders") | ||
1180 | } | ||
1181 | } | ||
1182 | + | ||
1183 | + return nil | ||
1184 | } | ||
1185 | |||
1186 | func (m *clientHelloMsg) unmarshal(data []byte) bool { | ||
1187 | @@ -613,9 +622,98 @@ type serverHelloMsg struct { | ||
1188 | selectedGroup CurveID | ||
1189 | } | ||
1190 | |||
1191 | -func (m *serverHelloMsg) marshal() []byte { | ||
1192 | +func (m *serverHelloMsg) marshal() ([]byte, error) { | ||
1193 | if m.raw != nil { | ||
1194 | - return m.raw | ||
1195 | + return m.raw, nil | ||
1196 | + } | ||
1197 | + | ||
1198 | + var exts cryptobyte.Builder | ||
1199 | + if m.ocspStapling { | ||
1200 | + exts.AddUint16(extensionStatusRequest) | ||
1201 | + exts.AddUint16(0) // empty extension_data | ||
1202 | + } | ||
1203 | + if m.ticketSupported { | ||
1204 | + exts.AddUint16(extensionSessionTicket) | ||
1205 | + exts.AddUint16(0) // empty extension_data | ||
1206 | + } | ||
1207 | + if m.secureRenegotiationSupported { | ||
1208 | + exts.AddUint16(extensionRenegotiationInfo) | ||
1209 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1210 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1211 | + exts.AddBytes(m.secureRenegotiation) | ||
1212 | + }) | ||
1213 | + }) | ||
1214 | + } | ||
1215 | + if len(m.alpnProtocol) > 0 { | ||
1216 | + exts.AddUint16(extensionALPN) | ||
1217 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1218 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1219 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1220 | + exts.AddBytes([]byte(m.alpnProtocol)) | ||
1221 | + }) | ||
1222 | + }) | ||
1223 | + }) | ||
1224 | + } | ||
1225 | + if len(m.scts) > 0 { | ||
1226 | + exts.AddUint16(extensionSCT) | ||
1227 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1228 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1229 | + for _, sct := range m.scts { | ||
1230 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1231 | + exts.AddBytes(sct) | ||
1232 | + }) | ||
1233 | + } | ||
1234 | + }) | ||
1235 | + }) | ||
1236 | + } | ||
1237 | + if m.supportedVersion != 0 { | ||
1238 | + exts.AddUint16(extensionSupportedVersions) | ||
1239 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1240 | + exts.AddUint16(m.supportedVersion) | ||
1241 | + }) | ||
1242 | + } | ||
1243 | + if m.serverShare.group != 0 { | ||
1244 | + exts.AddUint16(extensionKeyShare) | ||
1245 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1246 | + exts.AddUint16(uint16(m.serverShare.group)) | ||
1247 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1248 | + exts.AddBytes(m.serverShare.data) | ||
1249 | + }) | ||
1250 | + }) | ||
1251 | + } | ||
1252 | + if m.selectedIdentityPresent { | ||
1253 | + exts.AddUint16(extensionPreSharedKey) | ||
1254 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1255 | + exts.AddUint16(m.selectedIdentity) | ||
1256 | + }) | ||
1257 | + } | ||
1258 | + | ||
1259 | + if len(m.cookie) > 0 { | ||
1260 | + exts.AddUint16(extensionCookie) | ||
1261 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1262 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1263 | + exts.AddBytes(m.cookie) | ||
1264 | + }) | ||
1265 | + }) | ||
1266 | + } | ||
1267 | + if m.selectedGroup != 0 { | ||
1268 | + exts.AddUint16(extensionKeyShare) | ||
1269 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1270 | + exts.AddUint16(uint16(m.selectedGroup)) | ||
1271 | + }) | ||
1272 | + } | ||
1273 | + if len(m.supportedPoints) > 0 { | ||
1274 | + exts.AddUint16(extensionSupportedPoints) | ||
1275 | + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1276 | + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { | ||
1277 | + exts.AddBytes(m.supportedPoints) | ||
1278 | + }) | ||
1279 | + }) | ||
1280 | + } | ||
1281 | + | ||
1282 | + extBytes, err := exts.Bytes() | ||
1283 | + if err != nil { | ||
1284 | + return nil, err | ||
1285 | } | ||
1286 | |||
1287 | var b cryptobyte.Builder | ||
1288 | @@ -629,104 +727,15 @@ func (m *serverHelloMsg) marshal() []byt | ||
1289 | b.AddUint16(m.cipherSuite) | ||
1290 | b.AddUint8(m.compressionMethod) | ||
1291 | |||
1292 | - // If extensions aren't present, omit them. | ||
1293 | - var extensionsPresent bool | ||
1294 | - bWithoutExtensions := *b | ||
1295 | - | ||
1296 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1297 | - if m.ocspStapling { | ||
1298 | - b.AddUint16(extensionStatusRequest) | ||
1299 | - b.AddUint16(0) // empty extension_data | ||
1300 | - } | ||
1301 | - if m.ticketSupported { | ||
1302 | - b.AddUint16(extensionSessionTicket) | ||
1303 | - b.AddUint16(0) // empty extension_data | ||
1304 | - } | ||
1305 | - if m.secureRenegotiationSupported { | ||
1306 | - b.AddUint16(extensionRenegotiationInfo) | ||
1307 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1308 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1309 | - b.AddBytes(m.secureRenegotiation) | ||
1310 | - }) | ||
1311 | - }) | ||
1312 | - } | ||
1313 | - if len(m.alpnProtocol) > 0 { | ||
1314 | - b.AddUint16(extensionALPN) | ||
1315 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1316 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1317 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1318 | - b.AddBytes([]byte(m.alpnProtocol)) | ||
1319 | - }) | ||
1320 | - }) | ||
1321 | - }) | ||
1322 | - } | ||
1323 | - if len(m.scts) > 0 { | ||
1324 | - b.AddUint16(extensionSCT) | ||
1325 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1326 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1327 | - for _, sct := range m.scts { | ||
1328 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1329 | - b.AddBytes(sct) | ||
1330 | - }) | ||
1331 | - } | ||
1332 | - }) | ||
1333 | - }) | ||
1334 | - } | ||
1335 | - if m.supportedVersion != 0 { | ||
1336 | - b.AddUint16(extensionSupportedVersions) | ||
1337 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1338 | - b.AddUint16(m.supportedVersion) | ||
1339 | - }) | ||
1340 | - } | ||
1341 | - if m.serverShare.group != 0 { | ||
1342 | - b.AddUint16(extensionKeyShare) | ||
1343 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1344 | - b.AddUint16(uint16(m.serverShare.group)) | ||
1345 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1346 | - b.AddBytes(m.serverShare.data) | ||
1347 | - }) | ||
1348 | - }) | ||
1349 | - } | ||
1350 | - if m.selectedIdentityPresent { | ||
1351 | - b.AddUint16(extensionPreSharedKey) | ||
1352 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1353 | - b.AddUint16(m.selectedIdentity) | ||
1354 | - }) | ||
1355 | - } | ||
1356 | - | ||
1357 | - if len(m.cookie) > 0 { | ||
1358 | - b.AddUint16(extensionCookie) | ||
1359 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1360 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1361 | - b.AddBytes(m.cookie) | ||
1362 | - }) | ||
1363 | - }) | ||
1364 | - } | ||
1365 | - if m.selectedGroup != 0 { | ||
1366 | - b.AddUint16(extensionKeyShare) | ||
1367 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1368 | - b.AddUint16(uint16(m.selectedGroup)) | ||
1369 | - }) | ||
1370 | - } | ||
1371 | - if len(m.supportedPoints) > 0 { | ||
1372 | - b.AddUint16(extensionSupportedPoints) | ||
1373 | - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1374 | - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1375 | - b.AddBytes(m.supportedPoints) | ||
1376 | - }) | ||
1377 | - }) | ||
1378 | - } | ||
1379 | - | ||
1380 | - extensionsPresent = len(b.BytesOrPanic()) > 2 | ||
1381 | - }) | ||
1382 | - | ||
1383 | - if !extensionsPresent { | ||
1384 | - *b = bWithoutExtensions | ||
1385 | + if len(extBytes) > 0 { | ||
1386 | + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { | ||
1387 | + b.AddBytes(extBytes) | ||
1388 | + }) | ||
1389 | } | ||
1390 | }) | ||
1391 | |||
1392 | - m.raw = b.BytesOrPanic() | ||
1393 | - return m.raw | ||
1394 | + m.raw, err = b.Bytes() | ||
1395 | + return m.raw, err | ||
1396 | } | ||
1397 | |||
1398 | func (m *serverHelloMsg) unmarshal(data []byte) bool { | ||
1399 | @@ -844,9 +853,9 @@ type encryptedExtensionsMsg struct { | ||
1400 | alpnProtocol string | ||
1401 | } | ||
1402 | |||
1403 | -func (m *encryptedExtensionsMsg) marshal() []byte { | ||
1404 | +func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { | ||
1405 | if m.raw != nil { | ||
1406 | - return m.raw | ||
1407 | + return m.raw, nil | ||
1408 | } | ||
1409 | |||
1410 | var b cryptobyte.Builder | ||
1411 | @@ -866,8 +875,9 @@ func (m *encryptedExtensionsMsg) marshal | ||
1412 | }) | ||
1413 | }) | ||
1414 | |||
1415 | - m.raw = b.BytesOrPanic() | ||
1416 | - return m.raw | ||
1417 | + var err error | ||
1418 | + m.raw, err = b.Bytes() | ||
1419 | + return m.raw, err | ||
1420 | } | ||
1421 | |||
1422 | func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { | ||
1423 | @@ -915,10 +925,10 @@ func (m *encryptedExtensionsMsg) unmarsh | ||
1424 | |||
1425 | type endOfEarlyDataMsg struct{} | ||
1426 | |||
1427 | -func (m *endOfEarlyDataMsg) marshal() []byte { | ||
1428 | +func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { | ||
1429 | x := make([]byte, 4) | ||
1430 | x[0] = typeEndOfEarlyData | ||
1431 | - return x | ||
1432 | + return x, nil | ||
1433 | } | ||
1434 | |||
1435 | func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { | ||
1436 | @@ -930,9 +940,9 @@ type keyUpdateMsg struct { | ||
1437 | updateRequested bool | ||
1438 | } | ||
1439 | |||
1440 | -func (m *keyUpdateMsg) marshal() []byte { | ||
1441 | +func (m *keyUpdateMsg) marshal() ([]byte, error) { | ||
1442 | if m.raw != nil { | ||
1443 | - return m.raw | ||
1444 | + return m.raw, nil | ||
1445 | } | ||
1446 | |||
1447 | var b cryptobyte.Builder | ||
1448 | @@ -945,8 +955,9 @@ func (m *keyUpdateMsg) marshal() []byte | ||
1449 | } | ||
1450 | }) | ||
1451 | |||
1452 | - m.raw = b.BytesOrPanic() | ||
1453 | - return m.raw | ||
1454 | + var err error | ||
1455 | + m.raw, err = b.Bytes() | ||
1456 | + return m.raw, err | ||
1457 | } | ||
1458 | |||
1459 | func (m *keyUpdateMsg) unmarshal(data []byte) bool { | ||
1460 | @@ -978,9 +989,9 @@ type newSessionTicketMsgTLS13 struct { | ||
1461 | maxEarlyData uint32 | ||
1462 | } | ||
1463 | |||
1464 | -func (m *newSessionTicketMsgTLS13) marshal() []byte { | ||
1465 | +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { | ||
1466 | if m.raw != nil { | ||
1467 | - return m.raw | ||
1468 | + return m.raw, nil | ||
1469 | } | ||
1470 | |||
1471 | var b cryptobyte.Builder | ||
1472 | @@ -1005,8 +1016,9 @@ func (m *newSessionTicketMsgTLS13) marsh | ||
1473 | }) | ||
1474 | }) | ||
1475 | |||
1476 | - m.raw = b.BytesOrPanic() | ||
1477 | - return m.raw | ||
1478 | + var err error | ||
1479 | + m.raw, err = b.Bytes() | ||
1480 | + return m.raw, err | ||
1481 | } | ||
1482 | |||
1483 | func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { | ||
1484 | @@ -1059,9 +1071,9 @@ type certificateRequestMsgTLS13 struct { | ||
1485 | certificateAuthorities [][]byte | ||
1486 | } | ||
1487 | |||
1488 | -func (m *certificateRequestMsgTLS13) marshal() []byte { | ||
1489 | +func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { | ||
1490 | if m.raw != nil { | ||
1491 | - return m.raw | ||
1492 | + return m.raw, nil | ||
1493 | } | ||
1494 | |||
1495 | var b cryptobyte.Builder | ||
1496 | @@ -1120,8 +1132,9 @@ func (m *certificateRequestMsgTLS13) mar | ||
1497 | }) | ||
1498 | }) | ||
1499 | |||
1500 | - m.raw = b.BytesOrPanic() | ||
1501 | - return m.raw | ||
1502 | + var err error | ||
1503 | + m.raw, err = b.Bytes() | ||
1504 | + return m.raw, err | ||
1505 | } | ||
1506 | |||
1507 | func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { | ||
1508 | @@ -1205,9 +1218,9 @@ type certificateMsg struct { | ||
1509 | certificates [][]byte | ||
1510 | } | ||
1511 | |||
1512 | -func (m *certificateMsg) marshal() (x []byte) { | ||
1513 | +func (m *certificateMsg) marshal() ([]byte, error) { | ||
1514 | if m.raw != nil { | ||
1515 | - return m.raw | ||
1516 | + return m.raw, nil | ||
1517 | } | ||
1518 | |||
1519 | var i int | ||
1520 | @@ -1216,7 +1229,7 @@ func (m *certificateMsg) marshal() (x [] | ||
1521 | } | ||
1522 | |||
1523 | length := 3 + 3*len(m.certificates) + i | ||
1524 | - x = make([]byte, 4+length) | ||
1525 | + x := make([]byte, 4+length) | ||
1526 | x[0] = typeCertificate | ||
1527 | x[1] = uint8(length >> 16) | ||
1528 | x[2] = uint8(length >> 8) | ||
1529 | @@ -1237,7 +1250,7 @@ func (m *certificateMsg) marshal() (x [] | ||
1530 | } | ||
1531 | |||
1532 | m.raw = x | ||
1533 | - return | ||
1534 | + return m.raw, nil | ||
1535 | } | ||
1536 | |||
1537 | func (m *certificateMsg) unmarshal(data []byte) bool { | ||
1538 | @@ -1284,9 +1297,9 @@ type certificateMsgTLS13 struct { | ||
1539 | scts bool | ||
1540 | } | ||
1541 | |||
1542 | -func (m *certificateMsgTLS13) marshal() []byte { | ||
1543 | +func (m *certificateMsgTLS13) marshal() ([]byte, error) { | ||
1544 | if m.raw != nil { | ||
1545 | - return m.raw | ||
1546 | + return m.raw, nil | ||
1547 | } | ||
1548 | |||
1549 | var b cryptobyte.Builder | ||
1550 | @@ -1304,8 +1317,9 @@ func (m *certificateMsgTLS13) marshal() | ||
1551 | marshalCertificate(b, certificate) | ||
1552 | }) | ||
1553 | |||
1554 | - m.raw = b.BytesOrPanic() | ||
1555 | - return m.raw | ||
1556 | + var err error | ||
1557 | + m.raw, err = b.Bytes() | ||
1558 | + return m.raw, err | ||
1559 | } | ||
1560 | |||
1561 | func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { | ||
1562 | @@ -1428,9 +1442,9 @@ type serverKeyExchangeMsg struct { | ||
1563 | key []byte | ||
1564 | } | ||
1565 | |||
1566 | -func (m *serverKeyExchangeMsg) marshal() []byte { | ||
1567 | +func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { | ||
1568 | if m.raw != nil { | ||
1569 | - return m.raw | ||
1570 | + return m.raw, nil | ||
1571 | } | ||
1572 | length := len(m.key) | ||
1573 | x := make([]byte, length+4) | ||
1574 | @@ -1441,7 +1455,7 @@ func (m *serverKeyExchangeMsg) marshal() | ||
1575 | copy(x[4:], m.key) | ||
1576 | |||
1577 | m.raw = x | ||
1578 | - return x | ||
1579 | + return x, nil | ||
1580 | } | ||
1581 | |||
1582 | func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { | ||
1583 | @@ -1458,9 +1472,9 @@ type certificateStatusMsg struct { | ||
1584 | response []byte | ||
1585 | } | ||
1586 | |||
1587 | -func (m *certificateStatusMsg) marshal() []byte { | ||
1588 | +func (m *certificateStatusMsg) marshal() ([]byte, error) { | ||
1589 | if m.raw != nil { | ||
1590 | - return m.raw | ||
1591 | + return m.raw, nil | ||
1592 | } | ||
1593 | |||
1594 | var b cryptobyte.Builder | ||
1595 | @@ -1472,8 +1486,9 @@ func (m *certificateStatusMsg) marshal() | ||
1596 | }) | ||
1597 | }) | ||
1598 | |||
1599 | - m.raw = b.BytesOrPanic() | ||
1600 | - return m.raw | ||
1601 | + var err error | ||
1602 | + m.raw, err = b.Bytes() | ||
1603 | + return m.raw, err | ||
1604 | } | ||
1605 | |||
1606 | func (m *certificateStatusMsg) unmarshal(data []byte) bool { | ||
1607 | @@ -1492,10 +1507,10 @@ func (m *certificateStatusMsg) unmarshal | ||
1608 | |||
1609 | type serverHelloDoneMsg struct{} | ||
1610 | |||
1611 | -func (m *serverHelloDoneMsg) marshal() []byte { | ||
1612 | +func (m *serverHelloDoneMsg) marshal() ([]byte, error) { | ||
1613 | x := make([]byte, 4) | ||
1614 | x[0] = typeServerHelloDone | ||
1615 | - return x | ||
1616 | + return x, nil | ||
1617 | } | ||
1618 | |||
1619 | func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { | ||
1620 | @@ -1507,9 +1522,9 @@ type clientKeyExchangeMsg struct { | ||
1621 | ciphertext []byte | ||
1622 | } | ||
1623 | |||
1624 | -func (m *clientKeyExchangeMsg) marshal() []byte { | ||
1625 | +func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { | ||
1626 | if m.raw != nil { | ||
1627 | - return m.raw | ||
1628 | + return m.raw, nil | ||
1629 | } | ||
1630 | length := len(m.ciphertext) | ||
1631 | x := make([]byte, length+4) | ||
1632 | @@ -1520,7 +1535,7 @@ func (m *clientKeyExchangeMsg) marshal() | ||
1633 | copy(x[4:], m.ciphertext) | ||
1634 | |||
1635 | m.raw = x | ||
1636 | - return x | ||
1637 | + return x, nil | ||
1638 | } | ||
1639 | |||
1640 | func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { | ||
1641 | @@ -1541,9 +1556,9 @@ type finishedMsg struct { | ||
1642 | verifyData []byte | ||
1643 | } | ||
1644 | |||
1645 | -func (m *finishedMsg) marshal() []byte { | ||
1646 | +func (m *finishedMsg) marshal() ([]byte, error) { | ||
1647 | if m.raw != nil { | ||
1648 | - return m.raw | ||
1649 | + return m.raw, nil | ||
1650 | } | ||
1651 | |||
1652 | var b cryptobyte.Builder | ||
1653 | @@ -1552,8 +1567,9 @@ func (m *finishedMsg) marshal() []byte { | ||
1654 | b.AddBytes(m.verifyData) | ||
1655 | }) | ||
1656 | |||
1657 | - m.raw = b.BytesOrPanic() | ||
1658 | - return m.raw | ||
1659 | + var err error | ||
1660 | + m.raw, err = b.Bytes() | ||
1661 | + return m.raw, err | ||
1662 | } | ||
1663 | |||
1664 | func (m *finishedMsg) unmarshal(data []byte) bool { | ||
1665 | @@ -1575,9 +1591,9 @@ type certificateRequestMsg struct { | ||
1666 | certificateAuthorities [][]byte | ||
1667 | } | ||
1668 | |||
1669 | -func (m *certificateRequestMsg) marshal() (x []byte) { | ||
1670 | +func (m *certificateRequestMsg) marshal() ([]byte, error) { | ||
1671 | if m.raw != nil { | ||
1672 | - return m.raw | ||
1673 | + return m.raw, nil | ||
1674 | } | ||
1675 | |||
1676 | // See RFC 4346, Section 7.4.4. | ||
1677 | @@ -1592,7 +1608,7 @@ func (m *certificateRequestMsg) marshal( | ||
1678 | length += 2 + 2*len(m.supportedSignatureAlgorithms) | ||
1679 | } | ||
1680 | |||
1681 | - x = make([]byte, 4+length) | ||
1682 | + x := make([]byte, 4+length) | ||
1683 | x[0] = typeCertificateRequest | ||
1684 | x[1] = uint8(length >> 16) | ||
1685 | x[2] = uint8(length >> 8) | ||
1686 | @@ -1627,7 +1643,7 @@ func (m *certificateRequestMsg) marshal( | ||
1687 | } | ||
1688 | |||
1689 | m.raw = x | ||
1690 | - return | ||
1691 | + return m.raw, nil | ||
1692 | } | ||
1693 | |||
1694 | func (m *certificateRequestMsg) unmarshal(data []byte) bool { | ||
1695 | @@ -1713,9 +1729,9 @@ type certificateVerifyMsg struct { | ||
1696 | signature []byte | ||
1697 | } | ||
1698 | |||
1699 | -func (m *certificateVerifyMsg) marshal() (x []byte) { | ||
1700 | +func (m *certificateVerifyMsg) marshal() ([]byte, error) { | ||
1701 | if m.raw != nil { | ||
1702 | - return m.raw | ||
1703 | + return m.raw, nil | ||
1704 | } | ||
1705 | |||
1706 | var b cryptobyte.Builder | ||
1707 | @@ -1729,8 +1745,9 @@ func (m *certificateVerifyMsg) marshal() | ||
1708 | }) | ||
1709 | }) | ||
1710 | |||
1711 | - m.raw = b.BytesOrPanic() | ||
1712 | - return m.raw | ||
1713 | + var err error | ||
1714 | + m.raw, err = b.Bytes() | ||
1715 | + return m.raw, err | ||
1716 | } | ||
1717 | |||
1718 | func (m *certificateVerifyMsg) unmarshal(data []byte) bool { | ||
1719 | @@ -1753,15 +1770,15 @@ type newSessionTicketMsg struct { | ||
1720 | ticket []byte | ||
1721 | } | ||
1722 | |||
1723 | -func (m *newSessionTicketMsg) marshal() (x []byte) { | ||
1724 | +func (m *newSessionTicketMsg) marshal() ([]byte, error) { | ||
1725 | if m.raw != nil { | ||
1726 | - return m.raw | ||
1727 | + return m.raw, nil | ||
1728 | } | ||
1729 | |||
1730 | // See RFC 5077, Section 3.3. | ||
1731 | ticketLen := len(m.ticket) | ||
1732 | length := 2 + 4 + ticketLen | ||
1733 | - x = make([]byte, 4+length) | ||
1734 | + x := make([]byte, 4+length) | ||
1735 | x[0] = typeNewSessionTicket | ||
1736 | x[1] = uint8(length >> 16) | ||
1737 | x[2] = uint8(length >> 8) | ||
1738 | @@ -1772,7 +1789,7 @@ func (m *newSessionTicketMsg) marshal() | ||
1739 | |||
1740 | m.raw = x | ||
1741 | |||
1742 | - return | ||
1743 | + return m.raw, nil | ||
1744 | } | ||
1745 | |||
1746 | func (m *newSessionTicketMsg) unmarshal(data []byte) bool { | ||
1747 | @@ -1800,10 +1817,25 @@ func (m *newSessionTicketMsg) unmarshal( | ||
1748 | type helloRequestMsg struct { | ||
1749 | } | ||
1750 | |||
1751 | -func (*helloRequestMsg) marshal() []byte { | ||
1752 | - return []byte{typeHelloRequest, 0, 0, 0} | ||
1753 | +func (*helloRequestMsg) marshal() ([]byte, error) { | ||
1754 | + return []byte{typeHelloRequest, 0, 0, 0}, nil | ||
1755 | } | ||
1756 | |||
1757 | func (*helloRequestMsg) unmarshal(data []byte) bool { | ||
1758 | return len(data) == 4 | ||
1759 | } | ||
1760 | + | ||
1761 | +type transcriptHash interface { | ||
1762 | + Write([]byte) (int, error) | ||
1763 | +} | ||
1764 | + | ||
1765 | +// transcriptMsg is a helper used to marshal and hash messages which typically | ||
1766 | +// are not written to the wire, and as such aren't hashed during Conn.writeRecord. | ||
1767 | +func transcriptMsg(msg handshakeMessage, h transcriptHash) error { | ||
1768 | + data, err := msg.marshal() | ||
1769 | + if err != nil { | ||
1770 | + return err | ||
1771 | + } | ||
1772 | + h.Write(data) | ||
1773 | + return nil | ||
1774 | +} | ||
1775 | --- go.orig/src/crypto/tls/handshake_messages_test.go | ||
1776 | +++ go/src/crypto/tls/handshake_messages_test.go | ||
1777 | @@ -37,6 +37,15 @@ var tests = []interface{}{ | ||
1778 | &certificateMsgTLS13{}, | ||
1779 | } | ||
1780 | |||
1781 | +func mustMarshal(t *testing.T, msg handshakeMessage) []byte { | ||
1782 | + t.Helper() | ||
1783 | + b, err := msg.marshal() | ||
1784 | + if err != nil { | ||
1785 | + t.Fatal(err) | ||
1786 | + } | ||
1787 | + return b | ||
1788 | +} | ||
1789 | + | ||
1790 | func TestMarshalUnmarshal(t *testing.T) { | ||
1791 | rand := rand.New(rand.NewSource(time.Now().UnixNano())) | ||
1792 | |||
1793 | @@ -55,7 +64,7 @@ func TestMarshalUnmarshal(t *testing.T) | ||
1794 | } | ||
1795 | |||
1796 | m1 := v.Interface().(handshakeMessage) | ||
1797 | - marshaled := m1.marshal() | ||
1798 | + marshaled := mustMarshal(t, m1) | ||
1799 | m2 := iface.(handshakeMessage) | ||
1800 | if !m2.unmarshal(marshaled) { | ||
1801 | t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) | ||
1802 | @@ -408,12 +417,12 @@ func TestRejectEmptySCTList(t *testing.T | ||
1803 | |||
1804 | var random [32]byte | ||
1805 | sct := []byte{0x42, 0x42, 0x42, 0x42} | ||
1806 | - serverHello := serverHelloMsg{ | ||
1807 | + serverHello := &serverHelloMsg{ | ||
1808 | vers: VersionTLS12, | ||
1809 | random: random[:], | ||
1810 | scts: [][]byte{sct}, | ||
1811 | } | ||
1812 | - serverHelloBytes := serverHello.marshal() | ||
1813 | + serverHelloBytes := mustMarshal(t, serverHello) | ||
1814 | |||
1815 | var serverHelloCopy serverHelloMsg | ||
1816 | if !serverHelloCopy.unmarshal(serverHelloBytes) { | ||
1817 | @@ -451,12 +460,12 @@ func TestRejectEmptySCT(t *testing.T) { | ||
1818 | // not be zero length. | ||
1819 | |||
1820 | var random [32]byte | ||
1821 | - serverHello := serverHelloMsg{ | ||
1822 | + serverHello := &serverHelloMsg{ | ||
1823 | vers: VersionTLS12, | ||
1824 | random: random[:], | ||
1825 | scts: [][]byte{nil}, | ||
1826 | } | ||
1827 | - serverHelloBytes := serverHello.marshal() | ||
1828 | + serverHelloBytes := mustMarshal(t, serverHello) | ||
1829 | |||
1830 | var serverHelloCopy serverHelloMsg | ||
1831 | if serverHelloCopy.unmarshal(serverHelloBytes) { | ||
1832 | --- go.orig/src/crypto/tls/handshake_server.go | ||
1833 | +++ go/src/crypto/tls/handshake_server.go | ||
1834 | @@ -129,7 +129,9 @@ func (hs *serverHandshakeState) handshak | ||
1835 | |||
1836 | // readClientHello reads a ClientHello message and selects the protocol version. | ||
1837 | func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { | ||
1838 | - msg, err := c.readHandshake() | ||
1839 | + // clientHelloMsg is included in the transcript, but we haven't initialized | ||
1840 | + // it yet. The respective handshake functions will record it themselves. | ||
1841 | + msg, err := c.readHandshake(nil) | ||
1842 | if err != nil { | ||
1843 | return nil, err | ||
1844 | } | ||
1845 | @@ -456,9 +458,10 @@ func (hs *serverHandshakeState) doResume | ||
1846 | hs.hello.ticketSupported = hs.sessionState.usedOldKey | ||
1847 | hs.finishedHash = newFinishedHash(c.vers, hs.suite) | ||
1848 | hs.finishedHash.discardHandshakeBuffer() | ||
1849 | - hs.finishedHash.Write(hs.clientHello.marshal()) | ||
1850 | - hs.finishedHash.Write(hs.hello.marshal()) | ||
1851 | - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | ||
1852 | + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { | ||
1853 | + return err | ||
1854 | + } | ||
1855 | + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { | ||
1856 | return err | ||
1857 | } | ||
1858 | |||
1859 | @@ -496,24 +499,23 @@ func (hs *serverHandshakeState) doFullHa | ||
1860 | // certificates won't be used. | ||
1861 | hs.finishedHash.discardHandshakeBuffer() | ||
1862 | } | ||
1863 | - hs.finishedHash.Write(hs.clientHello.marshal()) | ||
1864 | - hs.finishedHash.Write(hs.hello.marshal()) | ||
1865 | - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | ||
1866 | + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { | ||
1867 | + return err | ||
1868 | + } | ||
1869 | + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { | ||
1870 | return err | ||
1871 | } | ||
1872 | |||
1873 | certMsg := new(certificateMsg) | ||
1874 | certMsg.certificates = hs.cert.Certificate | ||
1875 | - hs.finishedHash.Write(certMsg.marshal()) | ||
1876 | - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | ||
1877 | + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { | ||
1878 | return err | ||
1879 | } | ||
1880 | |||
1881 | if hs.hello.ocspStapling { | ||
1882 | certStatus := new(certificateStatusMsg) | ||
1883 | certStatus.response = hs.cert.OCSPStaple | ||
1884 | - hs.finishedHash.Write(certStatus.marshal()) | ||
1885 | - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { | ||
1886 | + if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { | ||
1887 | return err | ||
1888 | } | ||
1889 | } | ||
1890 | @@ -525,8 +527,7 @@ func (hs *serverHandshakeState) doFullHa | ||
1891 | return err | ||
1892 | } | ||
1893 | if skx != nil { | ||
1894 | - hs.finishedHash.Write(skx.marshal()) | ||
1895 | - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { | ||
1896 | + if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { | ||
1897 | return err | ||
1898 | } | ||
1899 | } | ||
1900 | @@ -552,15 +553,13 @@ func (hs *serverHandshakeState) doFullHa | ||
1901 | if c.config.ClientCAs != nil { | ||
1902 | certReq.certificateAuthorities = c.config.ClientCAs.Subjects() | ||
1903 | } | ||
1904 | - hs.finishedHash.Write(certReq.marshal()) | ||
1905 | - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { | ||
1906 | + if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil { | ||
1907 | return err | ||
1908 | } | ||
1909 | } | ||
1910 | |||
1911 | helloDone := new(serverHelloDoneMsg) | ||
1912 | - hs.finishedHash.Write(helloDone.marshal()) | ||
1913 | - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { | ||
1914 | + if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { | ||
1915 | return err | ||
1916 | } | ||
1917 | |||
1918 | @@ -570,7 +569,7 @@ func (hs *serverHandshakeState) doFullHa | ||
1919 | |||
1920 | var pub crypto.PublicKey // public key for client auth, if any | ||
1921 | |||
1922 | - msg, err := c.readHandshake() | ||
1923 | + msg, err := c.readHandshake(&hs.finishedHash) | ||
1924 | if err != nil { | ||
1925 | return err | ||
1926 | } | ||
1927 | @@ -583,7 +582,6 @@ func (hs *serverHandshakeState) doFullHa | ||
1928 | c.sendAlert(alertUnexpectedMessage) | ||
1929 | return unexpectedMessageError(certMsg, msg) | ||
1930 | } | ||
1931 | - hs.finishedHash.Write(certMsg.marshal()) | ||
1932 | |||
1933 | if err := c.processCertsFromClient(Certificate{ | ||
1934 | Certificate: certMsg.certificates, | ||
1935 | @@ -594,7 +592,7 @@ func (hs *serverHandshakeState) doFullHa | ||
1936 | pub = c.peerCertificates[0].PublicKey | ||
1937 | } | ||
1938 | |||
1939 | - msg, err = c.readHandshake() | ||
1940 | + msg, err = c.readHandshake(&hs.finishedHash) | ||
1941 | if err != nil { | ||
1942 | return err | ||
1943 | } | ||
1944 | @@ -612,7 +610,6 @@ func (hs *serverHandshakeState) doFullHa | ||
1945 | c.sendAlert(alertUnexpectedMessage) | ||
1946 | return unexpectedMessageError(ckx, msg) | ||
1947 | } | ||
1948 | - hs.finishedHash.Write(ckx.marshal()) | ||
1949 | |||
1950 | preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) | ||
1951 | if err != nil { | ||
1952 | @@ -632,7 +629,10 @@ func (hs *serverHandshakeState) doFullHa | ||
1953 | // to the client's certificate. This allows us to verify that the client is in | ||
1954 | // possession of the private key of the certificate. | ||
1955 | if len(c.peerCertificates) > 0 { | ||
1956 | - msg, err = c.readHandshake() | ||
1957 | + // certificateVerifyMsg is included in the transcript, but not until | ||
1958 | + // after we verify the handshake signature, since the state before | ||
1959 | + // this message was sent is used. | ||
1960 | + msg, err = c.readHandshake(nil) | ||
1961 | if err != nil { | ||
1962 | return err | ||
1963 | } | ||
1964 | @@ -667,7 +667,9 @@ func (hs *serverHandshakeState) doFullHa | ||
1965 | return errors.New("tls: invalid signature by the client certificate: " + err.Error()) | ||
1966 | } | ||
1967 | |||
1968 | - hs.finishedHash.Write(certVerify.marshal()) | ||
1969 | + if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { | ||
1970 | + return err | ||
1971 | + } | ||
1972 | } | ||
1973 | |||
1974 | hs.finishedHash.discardHandshakeBuffer() | ||
1975 | @@ -707,7 +709,10 @@ func (hs *serverHandshakeState) readFini | ||
1976 | return err | ||
1977 | } | ||
1978 | |||
1979 | - msg, err := c.readHandshake() | ||
1980 | + // finishedMsg is included in the transcript, but not until after we | ||
1981 | + // check the client version, since the state before this message was | ||
1982 | + // sent is used during verification. | ||
1983 | + msg, err := c.readHandshake(nil) | ||
1984 | if err != nil { | ||
1985 | return err | ||
1986 | } | ||
1987 | @@ -724,7 +729,10 @@ func (hs *serverHandshakeState) readFini | ||
1988 | return errors.New("tls: client's Finished message is incorrect") | ||
1989 | } | ||
1990 | |||
1991 | - hs.finishedHash.Write(clientFinished.marshal()) | ||
1992 | + if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { | ||
1993 | + return err | ||
1994 | + } | ||
1995 | + | ||
1996 | copy(out, verify) | ||
1997 | return nil | ||
1998 | } | ||
1999 | @@ -758,14 +766,16 @@ func (hs *serverHandshakeState) sendSess | ||
2000 | masterSecret: hs.masterSecret, | ||
2001 | certificates: certsFromClient, | ||
2002 | } | ||
2003 | - var err error | ||
2004 | - m.ticket, err = c.encryptTicket(state.marshal()) | ||
2005 | + stateBytes, err := state.marshal() | ||
2006 | + if err != nil { | ||
2007 | + return err | ||
2008 | + } | ||
2009 | + m.ticket, err = c.encryptTicket(stateBytes) | ||
2010 | if err != nil { | ||
2011 | return err | ||
2012 | } | ||
2013 | |||
2014 | - hs.finishedHash.Write(m.marshal()) | ||
2015 | - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { | ||
2016 | + if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { | ||
2017 | return err | ||
2018 | } | ||
2019 | |||
2020 | @@ -775,14 +785,13 @@ func (hs *serverHandshakeState) sendSess | ||
2021 | func (hs *serverHandshakeState) sendFinished(out []byte) error { | ||
2022 | c := hs.c | ||
2023 | |||
2024 | - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { | ||
2025 | + if err := c.writeChangeCipherRecord(); err != nil { | ||
2026 | return err | ||
2027 | } | ||
2028 | |||
2029 | finished := new(finishedMsg) | ||
2030 | finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) | ||
2031 | - hs.finishedHash.Write(finished.marshal()) | ||
2032 | - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | ||
2033 | + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { | ||
2034 | return err | ||
2035 | } | ||
2036 | |||
2037 | --- go.orig/src/crypto/tls/handshake_server_test.go | ||
2038 | +++ go/src/crypto/tls/handshake_server_test.go | ||
2039 | @@ -30,6 +30,13 @@ func testClientHello(t *testing.T, serve | ||
2040 | testClientHelloFailure(t, serverConfig, m, "") | ||
2041 | } | ||
2042 | |||
2043 | +// testFatal is a hack to prevent the compiler from complaining that there is a | ||
2044 | +// call to t.Fatal from a non-test goroutine | ||
2045 | +func testFatal(t *testing.T, err error) { | ||
2046 | + t.Helper() | ||
2047 | + t.Fatal(err) | ||
2048 | +} | ||
2049 | + | ||
2050 | func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { | ||
2051 | c, s := localPipe(t) | ||
2052 | go func() { | ||
2053 | @@ -37,7 +44,9 @@ func testClientHelloFailure(t *testing.T | ||
2054 | if ch, ok := m.(*clientHelloMsg); ok { | ||
2055 | cli.vers = ch.vers | ||
2056 | } | ||
2057 | - cli.writeRecord(recordTypeHandshake, m.marshal()) | ||
2058 | + if _, err := cli.writeHandshakeRecord(m, nil); err != nil { | ||
2059 | + testFatal(t, err) | ||
2060 | + } | ||
2061 | c.Close() | ||
2062 | }() | ||
2063 | ctx := context.Background() | ||
2064 | @@ -194,7 +203,9 @@ func TestRenegotiationExtension(t *testi | ||
2065 | go func() { | ||
2066 | cli := Client(c, testConfig) | ||
2067 | cli.vers = clientHello.vers | ||
2068 | - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) | ||
2069 | + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { | ||
2070 | + testFatal(t, err) | ||
2071 | + } | ||
2072 | |||
2073 | buf := make([]byte, 1024) | ||
2074 | n, err := c.Read(buf) | ||
2075 | @@ -253,8 +264,10 @@ func TestTLS12OnlyCipherSuites(t *testin | ||
2076 | go func() { | ||
2077 | cli := Client(c, testConfig) | ||
2078 | cli.vers = clientHello.vers | ||
2079 | - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) | ||
2080 | - reply, err := cli.readHandshake() | ||
2081 | + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { | ||
2082 | + testFatal(t, err) | ||
2083 | + } | ||
2084 | + reply, err := cli.readHandshake(nil) | ||
2085 | c.Close() | ||
2086 | if err != nil { | ||
2087 | replyChan <- err | ||
2088 | @@ -308,8 +321,10 @@ func TestTLSPointFormats(t *testing.T) { | ||
2089 | go func() { | ||
2090 | cli := Client(c, testConfig) | ||
2091 | cli.vers = clientHello.vers | ||
2092 | - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) | ||
2093 | - reply, err := cli.readHandshake() | ||
2094 | + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { | ||
2095 | + testFatal(t, err) | ||
2096 | + } | ||
2097 | + reply, err := cli.readHandshake(nil) | ||
2098 | c.Close() | ||
2099 | if err != nil { | ||
2100 | replyChan <- err | ||
2101 | @@ -1425,7 +1440,9 @@ func TestSNIGivenOnFailure(t *testing.T) | ||
2102 | go func() { | ||
2103 | cli := Client(c, testConfig) | ||
2104 | cli.vers = clientHello.vers | ||
2105 | - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) | ||
2106 | + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { | ||
2107 | + testFatal(t, err) | ||
2108 | + } | ||
2109 | c.Close() | ||
2110 | }() | ||
2111 | conn := Server(s, serverConfig) | ||
2112 | --- go.orig/src/crypto/tls/handshake_server_tls13.go | ||
2113 | +++ go/src/crypto/tls/handshake_server_tls13.go | ||
2114 | @@ -298,7 +298,12 @@ func (hs *serverHandshakeStateTLS13) che | ||
2115 | c.sendAlert(alertInternalError) | ||
2116 | return errors.New("tls: internal error: failed to clone hash") | ||
2117 | } | ||
2118 | - transcript.Write(hs.clientHello.marshalWithoutBinders()) | ||
2119 | + clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() | ||
2120 | + if err != nil { | ||
2121 | + c.sendAlert(alertInternalError) | ||
2122 | + return err | ||
2123 | + } | ||
2124 | + transcript.Write(clientHelloBytes) | ||
2125 | pskBinder := hs.suite.finishedHash(binderKey, transcript) | ||
2126 | if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { | ||
2127 | c.sendAlert(alertDecryptError) | ||
2128 | @@ -389,8 +394,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2129 | } | ||
2130 | hs.sentDummyCCS = true | ||
2131 | |||
2132 | - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) | ||
2133 | - return err | ||
2134 | + return hs.c.writeChangeCipherRecord() | ||
2135 | } | ||
2136 | |||
2137 | func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { | ||
2138 | @@ -398,7 +402,9 @@ func (hs *serverHandshakeStateTLS13) doH | ||
2139 | |||
2140 | // The first ClientHello gets double-hashed into the transcript upon a | ||
2141 | // HelloRetryRequest. See RFC 8446, Section 4.4.1. | ||
2142 | - hs.transcript.Write(hs.clientHello.marshal()) | ||
2143 | + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { | ||
2144 | + return err | ||
2145 | + } | ||
2146 | chHash := hs.transcript.Sum(nil) | ||
2147 | hs.transcript.Reset() | ||
2148 | hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) | ||
2149 | @@ -414,8 +420,7 @@ func (hs *serverHandshakeStateTLS13) doH | ||
2150 | selectedGroup: selectedGroup, | ||
2151 | } | ||
2152 | |||
2153 | - hs.transcript.Write(helloRetryRequest.marshal()) | ||
2154 | - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { | ||
2155 | + if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { | ||
2156 | return err | ||
2157 | } | ||
2158 | |||
2159 | @@ -423,7 +428,8 @@ func (hs *serverHandshakeStateTLS13) doH | ||
2160 | return err | ||
2161 | } | ||
2162 | |||
2163 | - msg, err := c.readHandshake() | ||
2164 | + // clientHelloMsg is not included in the transcript. | ||
2165 | + msg, err := c.readHandshake(nil) | ||
2166 | if err != nil { | ||
2167 | return err | ||
2168 | } | ||
2169 | @@ -514,9 +520,10 @@ func illegalClientHelloChange(ch, ch1 *c | ||
2170 | func (hs *serverHandshakeStateTLS13) sendServerParameters() error { | ||
2171 | c := hs.c | ||
2172 | |||
2173 | - hs.transcript.Write(hs.clientHello.marshal()) | ||
2174 | - hs.transcript.Write(hs.hello.marshal()) | ||
2175 | - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | ||
2176 | + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { | ||
2177 | + return err | ||
2178 | + } | ||
2179 | + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { | ||
2180 | return err | ||
2181 | } | ||
2182 | |||
2183 | @@ -559,8 +566,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2184 | encryptedExtensions.alpnProtocol = selectedProto | ||
2185 | c.clientProtocol = selectedProto | ||
2186 | |||
2187 | - hs.transcript.Write(encryptedExtensions.marshal()) | ||
2188 | - if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { | ||
2189 | + if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { | ||
2190 | return err | ||
2191 | } | ||
2192 | |||
2193 | @@ -589,8 +595,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2194 | certReq.certificateAuthorities = c.config.ClientCAs.Subjects() | ||
2195 | } | ||
2196 | |||
2197 | - hs.transcript.Write(certReq.marshal()) | ||
2198 | - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { | ||
2199 | + if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { | ||
2200 | return err | ||
2201 | } | ||
2202 | } | ||
2203 | @@ -601,8 +606,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2204 | certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 | ||
2205 | certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 | ||
2206 | |||
2207 | - hs.transcript.Write(certMsg.marshal()) | ||
2208 | - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | ||
2209 | + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { | ||
2210 | return err | ||
2211 | } | ||
2212 | |||
2213 | @@ -633,8 +637,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2214 | } | ||
2215 | certVerifyMsg.signature = sig | ||
2216 | |||
2217 | - hs.transcript.Write(certVerifyMsg.marshal()) | ||
2218 | - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { | ||
2219 | + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { | ||
2220 | return err | ||
2221 | } | ||
2222 | |||
2223 | @@ -648,8 +651,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2224 | verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), | ||
2225 | } | ||
2226 | |||
2227 | - hs.transcript.Write(finished.marshal()) | ||
2228 | - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | ||
2229 | + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { | ||
2230 | return err | ||
2231 | } | ||
2232 | |||
2233 | @@ -710,7 +712,9 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2234 | finishedMsg := &finishedMsg{ | ||
2235 | verifyData: hs.clientFinished, | ||
2236 | } | ||
2237 | - hs.transcript.Write(finishedMsg.marshal()) | ||
2238 | + if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { | ||
2239 | + return err | ||
2240 | + } | ||
2241 | |||
2242 | if !hs.shouldSendSessionTickets() { | ||
2243 | return nil | ||
2244 | @@ -735,8 +739,12 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2245 | SignedCertificateTimestamps: c.scts, | ||
2246 | }, | ||
2247 | } | ||
2248 | - var err error | ||
2249 | - m.label, err = c.encryptTicket(state.marshal()) | ||
2250 | + stateBytes, err := state.marshal() | ||
2251 | + if err != nil { | ||
2252 | + c.sendAlert(alertInternalError) | ||
2253 | + return err | ||
2254 | + } | ||
2255 | + m.label, err = c.encryptTicket(stateBytes) | ||
2256 | if err != nil { | ||
2257 | return err | ||
2258 | } | ||
2259 | @@ -755,7 +763,7 @@ func (hs *serverHandshakeStateTLS13) sen | ||
2260 | // ticket_nonce, which must be unique per connection, is always left at | ||
2261 | // zero because we only ever send one ticket per connection. | ||
2262 | |||
2263 | - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { | ||
2264 | + if _, err := c.writeHandshakeRecord(m, nil); err != nil { | ||
2265 | return err | ||
2266 | } | ||
2267 | |||
2268 | @@ -780,7 +788,7 @@ func (hs *serverHandshakeStateTLS13) rea | ||
2269 | // If we requested a client certificate, then the client must send a | ||
2270 | // certificate message. If it's empty, no CertificateVerify is sent. | ||
2271 | |||
2272 | - msg, err := c.readHandshake() | ||
2273 | + msg, err := c.readHandshake(hs.transcript) | ||
2274 | if err != nil { | ||
2275 | return err | ||
2276 | } | ||
2277 | @@ -790,7 +798,6 @@ func (hs *serverHandshakeStateTLS13) rea | ||
2278 | c.sendAlert(alertUnexpectedMessage) | ||
2279 | return unexpectedMessageError(certMsg, msg) | ||
2280 | } | ||
2281 | - hs.transcript.Write(certMsg.marshal()) | ||
2282 | |||
2283 | if err := c.processCertsFromClient(certMsg.certificate); err != nil { | ||
2284 | return err | ||
2285 | @@ -804,7 +811,10 @@ func (hs *serverHandshakeStateTLS13) rea | ||
2286 | } | ||
2287 | |||
2288 | if len(certMsg.certificate.Certificate) != 0 { | ||
2289 | - msg, err = c.readHandshake() | ||
2290 | + // certificateVerifyMsg is included in the transcript, but not until | ||
2291 | + // after we verify the handshake signature, since the state before | ||
2292 | + // this message was sent is used. | ||
2293 | + msg, err = c.readHandshake(nil) | ||
2294 | if err != nil { | ||
2295 | return err | ||
2296 | } | ||
2297 | @@ -835,7 +845,9 @@ func (hs *serverHandshakeStateTLS13) rea | ||
2298 | return errors.New("tls: invalid signature by the client certificate: " + err.Error()) | ||
2299 | } | ||
2300 | |||
2301 | - hs.transcript.Write(certVerify.marshal()) | ||
2302 | + if err := transcriptMsg(certVerify, hs.transcript); err != nil { | ||
2303 | + return err | ||
2304 | + } | ||
2305 | } | ||
2306 | |||
2307 | // If we waited until the client certificates to send session tickets, we | ||
2308 | @@ -850,7 +862,8 @@ func (hs *serverHandshakeStateTLS13) rea | ||
2309 | func (hs *serverHandshakeStateTLS13) readClientFinished() error { | ||
2310 | c := hs.c | ||
2311 | |||
2312 | - msg, err := c.readHandshake() | ||
2313 | + // finishedMsg is not included in the transcript. | ||
2314 | + msg, err := c.readHandshake(nil) | ||
2315 | if err != nil { | ||
2316 | return err | ||
2317 | } | ||
2318 | --- go.orig/src/crypto/tls/key_schedule.go | ||
2319 | +++ go/src/crypto/tls/key_schedule.go | ||
2320 | @@ -8,6 +8,7 @@ import ( | ||
2321 | "crypto/elliptic" | ||
2322 | "crypto/hmac" | ||
2323 | "errors" | ||
2324 | + "fmt" | ||
2325 | "hash" | ||
2326 | "io" | ||
2327 | "math/big" | ||
2328 | @@ -42,8 +43,24 @@ func (c *cipherSuiteTLS13) expandLabel(s | ||
2329 | hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { | ||
2330 | b.AddBytes(context) | ||
2331 | }) | ||
2332 | + hkdfLabelBytes, err := hkdfLabel.Bytes() | ||
2333 | + if err != nil { | ||
2334 | + // Rather than calling BytesOrPanic, we explicitly handle this error, in | ||
2335 | + // order to provide a reasonable error message. It should be basically | ||
2336 | + // impossible for this to panic, and routing errors back through the | ||
2337 | + // tree rooted in this function is quite painful. The labels are fixed | ||
2338 | + // size, and the context is either a fixed-length computed hash, or | ||
2339 | + // parsed from a field which has the same length limitation. As such, an | ||
2340 | + // error here is likely to only be caused during development. | ||
2341 | + // | ||
2342 | + // NOTE: another reasonable approach here might be to return a | ||
2343 | + // randomized slice if we encounter an error, which would break the | ||
2344 | + // connection, but avoid panicking. This would perhaps be safer but | ||
2345 | + // significantly more confusing to users. | ||
2346 | + panic(fmt.Errorf("failed to construct HKDF label: %s", err)) | ||
2347 | + } | ||
2348 | out := make([]byte, length) | ||
2349 | - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) | ||
2350 | + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) | ||
2351 | if err != nil || n != length { | ||
2352 | panic("tls: HKDF-Expand-Label invocation failed unexpectedly") | ||
2353 | } | ||
2354 | --- go.orig/src/crypto/tls/ticket.go | ||
2355 | +++ go/src/crypto/tls/ticket.go | ||
2356 | @@ -32,7 +32,7 @@ type sessionState struct { | ||
2357 | usedOldKey bool | ||
2358 | } | ||
2359 | |||
2360 | -func (m *sessionState) marshal() []byte { | ||
2361 | +func (m *sessionState) marshal() ([]byte, error) { | ||
2362 | var b cryptobyte.Builder | ||
2363 | b.AddUint16(m.vers) | ||
2364 | b.AddUint16(m.cipherSuite) | ||
2365 | @@ -47,7 +47,7 @@ func (m *sessionState) marshal() []byte | ||
2366 | }) | ||
2367 | } | ||
2368 | }) | ||
2369 | - return b.BytesOrPanic() | ||
2370 | + return b.Bytes() | ||
2371 | } | ||
2372 | |||
2373 | func (m *sessionState) unmarshal(data []byte) bool { | ||
2374 | @@ -86,7 +86,7 @@ type sessionStateTLS13 struct { | ||
2375 | certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; | ||
2376 | } | ||
2377 | |||
2378 | -func (m *sessionStateTLS13) marshal() []byte { | ||
2379 | +func (m *sessionStateTLS13) marshal() ([]byte, error) { | ||
2380 | var b cryptobyte.Builder | ||
2381 | b.AddUint16(VersionTLS13) | ||
2382 | b.AddUint8(0) // revision | ||
2383 | @@ -96,7 +96,7 @@ func (m *sessionStateTLS13) marshal() [] | ||
2384 | b.AddBytes(m.resumptionSecret) | ||
2385 | }) | ||
2386 | marshalCertificate(&b, m.certificate) | ||
2387 | - return b.BytesOrPanic() | ||
2388 | + return b.Bytes() | ||
2389 | } | ||
2390 | |||
2391 | func (m *sessionStateTLS13) unmarshal(data []byte) bool { | ||
diff --git a/meta/recipes-devtools/go/go-1.19/cve-2022-41725.patch b/meta/recipes-devtools/go/go-1.19/cve-2022-41725.patch new file mode 100644 index 0000000000..a71d07e3f1 --- /dev/null +++ b/meta/recipes-devtools/go/go-1.19/cve-2022-41725.patch | |||
@@ -0,0 +1,652 @@ | |||
1 | From 5c55ac9bf1e5f779220294c843526536605f42ab Mon Sep 17 00:00:00 2001 | ||
2 | From: Damien Neil <dneil@google.com> | ||
3 | Date: Wed, 25 Jan 2023 09:27:01 -0800 | ||
4 | Subject: [PATCH] [release-branch.go1.19] mime/multipart: limit memory/inode | ||
5 | consumption of ReadForm | ||
6 | |||
7 | Reader.ReadForm is documented as storing "up to maxMemory bytes + 10MB" | ||
8 | in memory. Parsed forms can consume substantially more memory than | ||
9 | this limit, since ReadForm does not account for map entry overhead | ||
10 | and MIME headers. | ||
11 | |||
12 | In addition, while the amount of disk memory consumed by ReadForm can | ||
13 | be constrained by limiting the size of the parsed input, ReadForm will | ||
14 | create one temporary file per form part stored on disk, potentially | ||
15 | consuming a large number of inodes. | ||
16 | |||
17 | Update ReadForm's memory accounting to include part names, | ||
18 | MIME headers, and map entry overhead. | ||
19 | |||
20 | Update ReadForm to store all on-disk file parts in a single | ||
21 | temporary file. | ||
22 | |||
23 | Files returned by FileHeader.Open are documented as having a concrete | ||
24 | type of *os.File when a file is stored on disk. The change to use a | ||
25 | single temporary file for all parts means that this is no longer the | ||
26 | case when a form contains more than a single file part stored on disk. | ||
27 | |||
28 | The previous behavior of storing each file part in a separate disk | ||
29 | file may be reenabled with GODEBUG=multipartfiles=distinct. | ||
30 | |||
31 | Update Reader.NextPart and Reader.NextRawPart to set a 10MiB cap | ||
32 | on the size of MIME headers. | ||
33 | |||
34 | Thanks to Jakob Ackermann (@das7pad) for reporting this issue. | ||
35 | |||
36 | Updates #58006 | ||
37 | Fixes #58362 | ||
38 | Fixes CVE-2022-41725 | ||
39 | |||
40 | Change-Id: Ibd780a6c4c83ac8bcfd3cbe344f042e9940f2eab | ||
41 | Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1714276 | ||
42 | Reviewed-by: Julie Qiu <julieqiu@google.com> | ||
43 | TryBot-Result: Security TryBots <security-trybots@go-security-trybots.iam.gserviceaccount.com> | ||
44 | Reviewed-by: Roland Shoemaker <bracewell@google.com> | ||
45 | Run-TryBot: Damien Neil <dneil@google.com> | ||
46 | (cherry picked from commit ed4664330edcd91b24914c9371c377c132dbce8c) | ||
47 | Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728949 | ||
48 | Reviewed-by: Tatiana Bradley <tatianabradley@google.com> | ||
49 | Run-TryBot: Roland Shoemaker <bracewell@google.com> | ||
50 | Reviewed-by: Damien Neil <dneil@google.com> | ||
51 | Reviewed-on: https://go-review.googlesource.com/c/go/+/468116 | ||
52 | TryBot-Result: Gopher Robot <gobot@golang.org> | ||
53 | Reviewed-by: Than McIntosh <thanm@google.com> | ||
54 | Run-TryBot: Michael Pratt <mpratt@google.com> | ||
55 | Auto-Submit: Michael Pratt <mpratt@google.com> | ||
56 | --- | ||
57 | |||
58 | CVE: CVE-2022-41725 | ||
59 | |||
60 | Upstream-Status: Backport [see text] | ||
61 | |||
62 | https://github.com/golong/go.git commit 5c55ac9bf1e5... | ||
63 | modified for reader.go | ||
64 | |||
65 | Signed-off-by: Joe Slater <joe.slater@windriver.com> | ||
66 | |||
67 | ___ | ||
68 | src/mime/multipart/formdata.go | 132 ++++++++++++++++++++----- | ||
69 | src/mime/multipart/formdata_test.go | 140 ++++++++++++++++++++++++++- | ||
70 | src/mime/multipart/multipart.go | 25 +++-- | ||
71 | src/mime/multipart/readmimeheader.go | 14 +++ | ||
72 | src/net/http/request_test.go | 2 +- | ||
73 | src/net/textproto/reader.go | 20 +++- | ||
74 | 6 files changed, 295 insertions(+), 38 deletions(-) | ||
75 | create mode 100644 src/mime/multipart/readmimeheader.go | ||
76 | |||
77 | --- go.orig/src/mime/multipart/formdata.go | ||
78 | +++ go/src/mime/multipart/formdata.go | ||
79 | @@ -7,6 +7,7 @@ package multipart | ||
80 | import ( | ||
81 | "bytes" | ||
82 | "errors" | ||
83 | + "internal/godebug" | ||
84 | "io" | ||
85 | "math" | ||
86 | "net/textproto" | ||
87 | @@ -33,23 +34,58 @@ func (r *Reader) ReadForm(maxMemory int6 | ||
88 | |||
89 | func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { | ||
90 | form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} | ||
91 | + var ( | ||
92 | + file *os.File | ||
93 | + fileOff int64 | ||
94 | + ) | ||
95 | + numDiskFiles := 0 | ||
96 | + multipartFiles := godebug.Get("multipartfiles") | ||
97 | + combineFiles := multipartFiles != "distinct" | ||
98 | defer func() { | ||
99 | + if file != nil { | ||
100 | + if cerr := file.Close(); err == nil { | ||
101 | + err = cerr | ||
102 | + } | ||
103 | + } | ||
104 | + if combineFiles && numDiskFiles > 1 { | ||
105 | + for _, fhs := range form.File { | ||
106 | + for _, fh := range fhs { | ||
107 | + fh.tmpshared = true | ||
108 | + } | ||
109 | + } | ||
110 | + } | ||
111 | if err != nil { | ||
112 | form.RemoveAll() | ||
113 | + if file != nil { | ||
114 | + os.Remove(file.Name()) | ||
115 | + } | ||
116 | } | ||
117 | }() | ||
118 | |||
119 | - // Reserve an additional 10 MB for non-file parts. | ||
120 | - maxValueBytes := maxMemory + int64(10<<20) | ||
121 | - if maxValueBytes <= 0 { | ||
122 | + // maxFileMemoryBytes is the maximum bytes of file data we will store in memory. | ||
123 | + // Data past this limit is written to disk. | ||
124 | + // This limit strictly applies to content, not metadata (filenames, MIME headers, etc.), | ||
125 | + // since metadata is always stored in memory, not disk. | ||
126 | + // | ||
127 | + // maxMemoryBytes is the maximum bytes we will store in memory, including file content, | ||
128 | + // non-file part values, metdata, and map entry overhead. | ||
129 | + // | ||
130 | + // We reserve an additional 10 MB in maxMemoryBytes for non-file data. | ||
131 | + // | ||
132 | + // The relationship between these parameters, as well as the overly-large and | ||
133 | + // unconfigurable 10 MB added on to maxMemory, is unfortunate but difficult to change | ||
134 | + // within the constraints of the API as documented. | ||
135 | + maxFileMemoryBytes := maxMemory | ||
136 | + maxMemoryBytes := maxMemory + int64(10<<20) | ||
137 | + if maxMemoryBytes <= 0 { | ||
138 | if maxMemory < 0 { | ||
139 | - maxValueBytes = 0 | ||
140 | + maxMemoryBytes = 0 | ||
141 | } else { | ||
142 | - maxValueBytes = math.MaxInt64 | ||
143 | + maxMemoryBytes = math.MaxInt64 | ||
144 | } | ||
145 | } | ||
146 | for { | ||
147 | - p, err := r.NextPart() | ||
148 | + p, err := r.nextPart(false, maxMemoryBytes) | ||
149 | if err == io.EOF { | ||
150 | break | ||
151 | } | ||
152 | @@ -63,16 +99,27 @@ func (r *Reader) readForm(maxMemory int6 | ||
153 | } | ||
154 | filename := p.FileName() | ||
155 | |||
156 | + // Multiple values for the same key (one map entry, longer slice) are cheaper | ||
157 | + // than the same number of values for different keys (many map entries), but | ||
158 | + // using a consistent per-value cost for overhead is simpler. | ||
159 | + maxMemoryBytes -= int64(len(name)) | ||
160 | + maxMemoryBytes -= 100 // map overhead | ||
161 | + if maxMemoryBytes < 0 { | ||
162 | + // We can't actually take this path, since nextPart would already have | ||
163 | + // rejected the MIME headers for being too large. Check anyway. | ||
164 | + return nil, ErrMessageTooLarge | ||
165 | + } | ||
166 | + | ||
167 | var b bytes.Buffer | ||
168 | |||
169 | if filename == "" { | ||
170 | // value, store as string in memory | ||
171 | - n, err := io.CopyN(&b, p, maxValueBytes+1) | ||
172 | + n, err := io.CopyN(&b, p, maxMemoryBytes+1) | ||
173 | if err != nil && err != io.EOF { | ||
174 | return nil, err | ||
175 | } | ||
176 | - maxValueBytes -= n | ||
177 | - if maxValueBytes < 0 { | ||
178 | + maxMemoryBytes -= n | ||
179 | + if maxMemoryBytes < 0 { | ||
180 | return nil, ErrMessageTooLarge | ||
181 | } | ||
182 | form.Value[name] = append(form.Value[name], b.String()) | ||
183 | @@ -80,35 +127,45 @@ func (r *Reader) readForm(maxMemory int6 | ||
184 | } | ||
185 | |||
186 | // file, store in memory or on disk | ||
187 | + maxMemoryBytes -= mimeHeaderSize(p.Header) | ||
188 | + if maxMemoryBytes < 0 { | ||
189 | + return nil, ErrMessageTooLarge | ||
190 | + } | ||
191 | fh := &FileHeader{ | ||
192 | Filename: filename, | ||
193 | Header: p.Header, | ||
194 | } | ||
195 | - n, err := io.CopyN(&b, p, maxMemory+1) | ||
196 | + n, err := io.CopyN(&b, p, maxFileMemoryBytes+1) | ||
197 | if err != nil && err != io.EOF { | ||
198 | return nil, err | ||
199 | } | ||
200 | - if n > maxMemory { | ||
201 | - // too big, write to disk and flush buffer | ||
202 | - file, err := os.CreateTemp("", "multipart-") | ||
203 | - if err != nil { | ||
204 | - return nil, err | ||
205 | + if n > maxFileMemoryBytes { | ||
206 | + if file == nil { | ||
207 | + file, err = os.CreateTemp(r.tempDir, "multipart-") | ||
208 | + if err != nil { | ||
209 | + return nil, err | ||
210 | + } | ||
211 | } | ||
212 | + numDiskFiles++ | ||
213 | size, err := io.Copy(file, io.MultiReader(&b, p)) | ||
214 | - if cerr := file.Close(); err == nil { | ||
215 | - err = cerr | ||
216 | - } | ||
217 | if err != nil { | ||
218 | - os.Remove(file.Name()) | ||
219 | return nil, err | ||
220 | } | ||
221 | fh.tmpfile = file.Name() | ||
222 | fh.Size = size | ||
223 | + fh.tmpoff = fileOff | ||
224 | + fileOff += size | ||
225 | + if !combineFiles { | ||
226 | + if err := file.Close(); err != nil { | ||
227 | + return nil, err | ||
228 | + } | ||
229 | + file = nil | ||
230 | + } | ||
231 | } else { | ||
232 | fh.content = b.Bytes() | ||
233 | fh.Size = int64(len(fh.content)) | ||
234 | - maxMemory -= n | ||
235 | - maxValueBytes -= n | ||
236 | + maxFileMemoryBytes -= n | ||
237 | + maxMemoryBytes -= n | ||
238 | } | ||
239 | form.File[name] = append(form.File[name], fh) | ||
240 | } | ||
241 | @@ -116,6 +173,17 @@ func (r *Reader) readForm(maxMemory int6 | ||
242 | return form, nil | ||
243 | } | ||
244 | |||
245 | +func mimeHeaderSize(h textproto.MIMEHeader) (size int64) { | ||
246 | + for k, vs := range h { | ||
247 | + size += int64(len(k)) | ||
248 | + size += 100 // map entry overhead | ||
249 | + for _, v := range vs { | ||
250 | + size += int64(len(v)) | ||
251 | + } | ||
252 | + } | ||
253 | + return size | ||
254 | +} | ||
255 | + | ||
256 | // Form is a parsed multipart form. | ||
257 | // Its File parts are stored either in memory or on disk, | ||
258 | // and are accessible via the *FileHeader's Open method. | ||
259 | @@ -133,7 +201,7 @@ func (f *Form) RemoveAll() error { | ||
260 | for _, fh := range fhs { | ||
261 | if fh.tmpfile != "" { | ||
262 | e := os.Remove(fh.tmpfile) | ||
263 | - if e != nil && err == nil { | ||
264 | + if e != nil && !errors.Is(e, os.ErrNotExist) && err == nil { | ||
265 | err = e | ||
266 | } | ||
267 | } | ||
268 | @@ -148,15 +216,25 @@ type FileHeader struct { | ||
269 | Header textproto.MIMEHeader | ||
270 | Size int64 | ||
271 | |||
272 | - content []byte | ||
273 | - tmpfile string | ||
274 | + content []byte | ||
275 | + tmpfile string | ||
276 | + tmpoff int64 | ||
277 | + tmpshared bool | ||
278 | } | ||
279 | |||
280 | // Open opens and returns the FileHeader's associated File. | ||
281 | func (fh *FileHeader) Open() (File, error) { | ||
282 | if b := fh.content; b != nil { | ||
283 | r := io.NewSectionReader(bytes.NewReader(b), 0, int64(len(b))) | ||
284 | - return sectionReadCloser{r}, nil | ||
285 | + return sectionReadCloser{r, nil}, nil | ||
286 | + } | ||
287 | + if fh.tmpshared { | ||
288 | + f, err := os.Open(fh.tmpfile) | ||
289 | + if err != nil { | ||
290 | + return nil, err | ||
291 | + } | ||
292 | + r := io.NewSectionReader(f, fh.tmpoff, fh.Size) | ||
293 | + return sectionReadCloser{r, f}, nil | ||
294 | } | ||
295 | return os.Open(fh.tmpfile) | ||
296 | } | ||
297 | @@ -175,8 +253,12 @@ type File interface { | ||
298 | |||
299 | type sectionReadCloser struct { | ||
300 | *io.SectionReader | ||
301 | + io.Closer | ||
302 | } | ||
303 | |||
304 | func (rc sectionReadCloser) Close() error { | ||
305 | + if rc.Closer != nil { | ||
306 | + return rc.Closer.Close() | ||
307 | + } | ||
308 | return nil | ||
309 | } | ||
310 | --- go.orig/src/mime/multipart/formdata_test.go | ||
311 | +++ go/src/mime/multipart/formdata_test.go | ||
312 | @@ -6,8 +6,10 @@ package multipart | ||
313 | |||
314 | import ( | ||
315 | "bytes" | ||
316 | + "fmt" | ||
317 | "io" | ||
318 | "math" | ||
319 | + "net/textproto" | ||
320 | "os" | ||
321 | "strings" | ||
322 | "testing" | ||
323 | @@ -208,8 +210,8 @@ Content-Disposition: form-data; name="la | ||
324 | maxMemory int64 | ||
325 | err error | ||
326 | }{ | ||
327 | - {"smaller", 50, nil}, | ||
328 | - {"exact-fit", 25, nil}, | ||
329 | + {"smaller", 50 + int64(len("largetext")) + 100, nil}, | ||
330 | + {"exact-fit", 25 + int64(len("largetext")) + 100, nil}, | ||
331 | {"too-large", 0, ErrMessageTooLarge}, | ||
332 | } | ||
333 | for _, tc := range testCases { | ||
334 | @@ -224,7 +226,7 @@ Content-Disposition: form-data; name="la | ||
335 | defer f.RemoveAll() | ||
336 | } | ||
337 | if tc.err != err { | ||
338 | - t.Fatalf("ReadForm error - got: %v; expected: %v", tc.err, err) | ||
339 | + t.Fatalf("ReadForm error - got: %v; expected: %v", err, tc.err) | ||
340 | } | ||
341 | if err == nil { | ||
342 | if g := f.Value["largetext"][0]; g != largeTextValue { | ||
343 | @@ -234,3 +236,135 @@ Content-Disposition: form-data; name="la | ||
344 | }) | ||
345 | } | ||
346 | } | ||
347 | + | ||
348 | +// TestReadForm_MetadataTooLarge verifies that we account for the size of field names, | ||
349 | +// MIME headers, and map entry overhead while limiting the memory consumption of parsed forms. | ||
350 | +func TestReadForm_MetadataTooLarge(t *testing.T) { | ||
351 | + for _, test := range []struct { | ||
352 | + name string | ||
353 | + f func(*Writer) | ||
354 | + }{{ | ||
355 | + name: "large name", | ||
356 | + f: func(fw *Writer) { | ||
357 | + name := strings.Repeat("a", 10<<20) | ||
358 | + w, _ := fw.CreateFormField(name) | ||
359 | + w.Write([]byte("value")) | ||
360 | + }, | ||
361 | + }, { | ||
362 | + name: "large MIME header", | ||
363 | + f: func(fw *Writer) { | ||
364 | + h := make(textproto.MIMEHeader) | ||
365 | + h.Set("Content-Disposition", `form-data; name="a"`) | ||
366 | + h.Set("X-Foo", strings.Repeat("a", 10<<20)) | ||
367 | + w, _ := fw.CreatePart(h) | ||
368 | + w.Write([]byte("value")) | ||
369 | + }, | ||
370 | + }, { | ||
371 | + name: "many parts", | ||
372 | + f: func(fw *Writer) { | ||
373 | + for i := 0; i < 110000; i++ { | ||
374 | + w, _ := fw.CreateFormField("f") | ||
375 | + w.Write([]byte("v")) | ||
376 | + } | ||
377 | + }, | ||
378 | + }} { | ||
379 | + t.Run(test.name, func(t *testing.T) { | ||
380 | + var buf bytes.Buffer | ||
381 | + fw := NewWriter(&buf) | ||
382 | + test.f(fw) | ||
383 | + if err := fw.Close(); err != nil { | ||
384 | + t.Fatal(err) | ||
385 | + } | ||
386 | + fr := NewReader(&buf, fw.Boundary()) | ||
387 | + _, err := fr.ReadForm(0) | ||
388 | + if err != ErrMessageTooLarge { | ||
389 | + t.Errorf("fr.ReadForm() = %v, want ErrMessageTooLarge", err) | ||
390 | + } | ||
391 | + }) | ||
392 | + } | ||
393 | +} | ||
394 | + | ||
395 | +// TestReadForm_ManyFiles_Combined tests that a multipart form containing many files only | ||
396 | +// results in a single on-disk file. | ||
397 | +func TestReadForm_ManyFiles_Combined(t *testing.T) { | ||
398 | + const distinct = false | ||
399 | + testReadFormManyFiles(t, distinct) | ||
400 | +} | ||
401 | + | ||
402 | +// TestReadForm_ManyFiles_Distinct tests that setting GODEBUG=multipartfiles=distinct | ||
403 | +// results in every file in a multipart form being placed in a distinct on-disk file. | ||
404 | +func TestReadForm_ManyFiles_Distinct(t *testing.T) { | ||
405 | + t.Setenv("GODEBUG", "multipartfiles=distinct") | ||
406 | + const distinct = true | ||
407 | + testReadFormManyFiles(t, distinct) | ||
408 | +} | ||
409 | + | ||
410 | +func testReadFormManyFiles(t *testing.T, distinct bool) { | ||
411 | + var buf bytes.Buffer | ||
412 | + fw := NewWriter(&buf) | ||
413 | + const numFiles = 10 | ||
414 | + for i := 0; i < numFiles; i++ { | ||
415 | + name := fmt.Sprint(i) | ||
416 | + w, err := fw.CreateFormFile(name, name) | ||
417 | + if err != nil { | ||
418 | + t.Fatal(err) | ||
419 | + } | ||
420 | + w.Write([]byte(name)) | ||
421 | + } | ||
422 | + if err := fw.Close(); err != nil { | ||
423 | + t.Fatal(err) | ||
424 | + } | ||
425 | + fr := NewReader(&buf, fw.Boundary()) | ||
426 | + fr.tempDir = t.TempDir() | ||
427 | + form, err := fr.ReadForm(0) | ||
428 | + if err != nil { | ||
429 | + t.Fatal(err) | ||
430 | + } | ||
431 | + for i := 0; i < numFiles; i++ { | ||
432 | + name := fmt.Sprint(i) | ||
433 | + if got := len(form.File[name]); got != 1 { | ||
434 | + t.Fatalf("form.File[%q] has %v entries, want 1", name, got) | ||
435 | + } | ||
436 | + fh := form.File[name][0] | ||
437 | + file, err := fh.Open() | ||
438 | + if err != nil { | ||
439 | + t.Fatalf("form.File[%q].Open() = %v", name, err) | ||
440 | + } | ||
441 | + if distinct { | ||
442 | + if _, ok := file.(*os.File); !ok { | ||
443 | + t.Fatalf("form.File[%q].Open: %T, want *os.File", name, file) | ||
444 | + } | ||
445 | + } | ||
446 | + got, err := io.ReadAll(file) | ||
447 | + file.Close() | ||
448 | + if string(got) != name || err != nil { | ||
449 | + t.Fatalf("read form.File[%q]: %q, %v; want %q, nil", name, string(got), err, name) | ||
450 | + } | ||
451 | + } | ||
452 | + dir, err := os.Open(fr.tempDir) | ||
453 | + if err != nil { | ||
454 | + t.Fatal(err) | ||
455 | + } | ||
456 | + defer dir.Close() | ||
457 | + names, err := dir.Readdirnames(0) | ||
458 | + if err != nil { | ||
459 | + t.Fatal(err) | ||
460 | + } | ||
461 | + wantNames := 1 | ||
462 | + if distinct { | ||
463 | + wantNames = numFiles | ||
464 | + } | ||
465 | + if len(names) != wantNames { | ||
466 | + t.Fatalf("temp dir contains %v files; want 1", len(names)) | ||
467 | + } | ||
468 | + if err := form.RemoveAll(); err != nil { | ||
469 | + t.Fatalf("form.RemoveAll() = %v", err) | ||
470 | + } | ||
471 | + names, err = dir.Readdirnames(0) | ||
472 | + if err != nil { | ||
473 | + t.Fatal(err) | ||
474 | + } | ||
475 | + if len(names) != 0 { | ||
476 | + t.Fatalf("temp dir contains %v files; want 0", len(names)) | ||
477 | + } | ||
478 | +} | ||
479 | --- go.orig/src/mime/multipart/multipart.go | ||
480 | +++ go/src/mime/multipart/multipart.go | ||
481 | @@ -128,12 +128,12 @@ func (r *stickyErrorReader) Read(p []byt | ||
482 | return n, r.err | ||
483 | } | ||
484 | |||
485 | -func newPart(mr *Reader, rawPart bool) (*Part, error) { | ||
486 | +func newPart(mr *Reader, rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { | ||
487 | bp := &Part{ | ||
488 | Header: make(map[string][]string), | ||
489 | mr: mr, | ||
490 | } | ||
491 | - if err := bp.populateHeaders(); err != nil { | ||
492 | + if err := bp.populateHeaders(maxMIMEHeaderSize); err != nil { | ||
493 | return nil, err | ||
494 | } | ||
495 | bp.r = partReader{bp} | ||
496 | @@ -149,12 +149,16 @@ func newPart(mr *Reader, rawPart bool) ( | ||
497 | return bp, nil | ||
498 | } | ||
499 | |||
500 | -func (bp *Part) populateHeaders() error { | ||
501 | +func (bp *Part) populateHeaders(maxMIMEHeaderSize int64) error { | ||
502 | r := textproto.NewReader(bp.mr.bufReader) | ||
503 | - header, err := r.ReadMIMEHeader() | ||
504 | + header, err := readMIMEHeader(r, maxMIMEHeaderSize) | ||
505 | if err == nil { | ||
506 | bp.Header = header | ||
507 | } | ||
508 | + // TODO: Add a distinguishable error to net/textproto. | ||
509 | + if err != nil && err.Error() == "message too large" { | ||
510 | + err = ErrMessageTooLarge | ||
511 | + } | ||
512 | return err | ||
513 | } | ||
514 | |||
515 | @@ -294,6 +298,7 @@ func (p *Part) Close() error { | ||
516 | // isn't supported. | ||
517 | type Reader struct { | ||
518 | bufReader *bufio.Reader | ||
519 | + tempDir string // used in tests | ||
520 | |||
521 | currentPart *Part | ||
522 | partsRead int | ||
523 | @@ -304,6 +309,10 @@ type Reader struct { | ||
524 | dashBoundary []byte // "--boundary" | ||
525 | } | ||
526 | |||
527 | +// maxMIMEHeaderSize is the maximum size of a MIME header we will parse, | ||
528 | +// including header keys, values, and map overhead. | ||
529 | +const maxMIMEHeaderSize = 10 << 20 | ||
530 | + | ||
531 | // NextPart returns the next part in the multipart or an error. | ||
532 | // When there are no more parts, the error io.EOF is returned. | ||
533 | // | ||
534 | @@ -311,7 +320,7 @@ type Reader struct { | ||
535 | // has a value of "quoted-printable", that header is instead | ||
536 | // hidden and the body is transparently decoded during Read calls. | ||
537 | func (r *Reader) NextPart() (*Part, error) { | ||
538 | - return r.nextPart(false) | ||
539 | + return r.nextPart(false, maxMIMEHeaderSize) | ||
540 | } | ||
541 | |||
542 | // NextRawPart returns the next part in the multipart or an error. | ||
543 | @@ -320,10 +329,10 @@ func (r *Reader) NextPart() (*Part, erro | ||
544 | // Unlike NextPart, it does not have special handling for | ||
545 | // "Content-Transfer-Encoding: quoted-printable". | ||
546 | func (r *Reader) NextRawPart() (*Part, error) { | ||
547 | - return r.nextPart(true) | ||
548 | + return r.nextPart(true, maxMIMEHeaderSize) | ||
549 | } | ||
550 | |||
551 | -func (r *Reader) nextPart(rawPart bool) (*Part, error) { | ||
552 | +func (r *Reader) nextPart(rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { | ||
553 | if r.currentPart != nil { | ||
554 | r.currentPart.Close() | ||
555 | } | ||
556 | @@ -348,7 +357,7 @@ func (r *Reader) nextPart(rawPart bool) | ||
557 | |||
558 | if r.isBoundaryDelimiterLine(line) { | ||
559 | r.partsRead++ | ||
560 | - bp, err := newPart(r, rawPart) | ||
561 | + bp, err := newPart(r, rawPart, maxMIMEHeaderSize) | ||
562 | if err != nil { | ||
563 | return nil, err | ||
564 | } | ||
565 | --- /dev/null | ||
566 | +++ go/src/mime/multipart/readmimeheader.go | ||
567 | @@ -0,0 +1,14 @@ | ||
568 | +// Copyright 2023 The Go Authors. All rights reserved. | ||
569 | +// Use of this source code is governed by a BSD-style | ||
570 | +// license that can be found in the LICENSE file. | ||
571 | +package multipart | ||
572 | + | ||
573 | +import ( | ||
574 | + "net/textproto" | ||
575 | + _ "unsafe" // for go:linkname | ||
576 | +) | ||
577 | + | ||
578 | +// readMIMEHeader is defined in package net/textproto. | ||
579 | +// | ||
580 | +//go:linkname readMIMEHeader net/textproto.readMIMEHeader | ||
581 | +func readMIMEHeader(r *textproto.Reader, lim int64) (textproto.MIMEHeader, error) | ||
582 | --- go.orig/src/net/http/request_test.go | ||
583 | +++ go/src/net/http/request_test.go | ||
584 | @@ -1110,7 +1110,7 @@ func testMissingFile(t *testing.T, req * | ||
585 | t.Errorf("FormFile file = %v, want nil", f) | ||
586 | } | ||
587 | if fh != nil { | ||
588 | - t.Errorf("FormFile file header = %q, want nil", fh) | ||
589 | + t.Errorf("FormFile file header = %v, want nil", fh) | ||
590 | } | ||
591 | if err != ErrMissingFile { | ||
592 | t.Errorf("FormFile err = %q, want ErrMissingFile", err) | ||
593 | --- go.orig/src/net/textproto/reader.go | ||
594 | +++ go/src/net/textproto/reader.go | ||
595 | @@ -7,8 +7,10 @@ package textproto | ||
596 | import ( | ||
597 | "bufio" | ||
598 | "bytes" | ||
599 | + "errors" | ||
600 | "fmt" | ||
601 | "io" | ||
602 | + "math" | ||
603 | "strconv" | ||
604 | "strings" | ||
605 | "sync" | ||
606 | @@ -481,6 +483,12 @@ func (r *Reader) ReadDotLines() ([]strin | ||
607 | // } | ||
608 | // | ||
609 | func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { | ||
610 | + return readMIMEHeader(r, math.MaxInt64) | ||
611 | +} | ||
612 | + | ||
613 | +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. | ||
614 | +// It is called by the mime/multipart package. | ||
615 | +func readMIMEHeader(r *Reader, lim int64) (MIMEHeader, error) { | ||
616 | // Avoid lots of small slice allocations later by allocating one | ||
617 | // large one ahead of time which we'll cut up into smaller | ||
618 | // slices. If this isn't big enough later, we allocate small ones. | ||
619 | @@ -521,6 +529,16 @@ func (r *Reader) ReadMIMEHeader() (MIMEH | ||
620 | continue | ||
621 | } | ||
622 | |||
623 | + // backport 5c55ac9bf1e5f779220294c843526536605f42ab | ||
624 | + // | ||
625 | + // value is computed as | ||
626 | + // | ||
627 | + // value := string(bytes.TrimLeft(v, " \t")) | ||
628 | + // | ||
629 | + // in the original patch from 1.19. This relies on | ||
630 | + // 'v' which does not exist in 1.17. We leave the | ||
631 | + // 1.17 method unchanged. | ||
632 | + | ||
633 | // Skip initial spaces in value. | ||
634 | i++ // skip colon | ||
635 | for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { | ||
636 | @@ -529,6 +547,16 @@ func (r *Reader) ReadMIMEHeader() (MIMEH | ||
637 | value := string(kv[i:]) | ||
638 | |||
639 | vv := m[key] | ||
640 | + if vv == nil { | ||
641 | + lim -= int64(len(key)) | ||
642 | + lim -= 100 // map entry overhead | ||
643 | + } | ||
644 | + lim -= int64(len(value)) | ||
645 | + if lim < 0 { | ||
646 | + // TODO: This should be a distinguishable error (ErrMessageTooLarge) | ||
647 | + // to allow mime/multipart to detect it. | ||
648 | + return m, errors.New("message too large") | ||
649 | + } | ||
650 | if vv == nil && len(strs) > 0 { | ||
651 | // More than likely this will be a single-element key. | ||
652 | // Most headers aren't multi-valued. | ||