Skip to content

Commit 42c6006

Browse files
committed
Simplify ProtocolSwitchStrategy by Leveraging ProtocolVersionParser
Unify HTTP and TLS token parsing in the Upgrade header by replacing custom version parsing with ProtocolVersionParser. This change removes redundant code and ensures that only supported protocols (HTTP/ and TLS tokens) are accepted, while all other upgrade protocols are rejected as unsupported.
1 parent 5e62dac commit 42c6006

2 files changed

Lines changed: 219 additions & 20 deletions

File tree

httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java

Lines changed: 101 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,22 @@
2727
package org.apache.hc.client5.http.impl;
2828

2929
import java.util.Iterator;
30+
import java.util.concurrent.atomic.AtomicReference;
3031

3132
import org.apache.hc.core5.annotation.Internal;
33+
import org.apache.hc.core5.http.FormattedHeader;
34+
import org.apache.hc.core5.http.Header;
3235
import org.apache.hc.core5.http.HttpHeaders;
3336
import org.apache.hc.core5.http.HttpMessage;
37+
import org.apache.hc.core5.http.HttpVersion;
3438
import org.apache.hc.core5.http.ParseException;
3539
import org.apache.hc.core5.http.ProtocolException;
3640
import org.apache.hc.core5.http.ProtocolVersion;
37-
import org.apache.hc.core5.http.message.MessageSupport;
41+
import org.apache.hc.core5.http.ProtocolVersionParser;
3842
import org.apache.hc.core5.http.ssl.TLS;
43+
import org.apache.hc.core5.util.Args;
44+
import org.apache.hc.core5.util.CharArrayBuffer;
45+
import org.apache.hc.core5.util.Tokenizer;
3946

4047
/**
4148
* Protocol switch handler.
@@ -45,31 +52,106 @@
4552
@Internal
4653
public final class ProtocolSwitchStrategy {
4754

48-
enum ProtocolSwitch { FAILURE, TLS }
55+
private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE;
56+
57+
private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE;
58+
59+
private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(',');
60+
61+
@FunctionalInterface
62+
private interface HeaderConsumer {
63+
void accept(CharSequence buffer, Tokenizer.Cursor cursor) throws ProtocolException;
64+
}
4965

5066
public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
51-
final Iterator<String> it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE);
67+
final AtomicReference<ProtocolVersion> tlsUpgrade = new AtomicReference<>();
5268

53-
ProtocolVersion tlsUpgrade = null;
54-
while (it.hasNext()) {
55-
final String token = it.next();
56-
if (token.startsWith("TLS")) {
57-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
58-
try {
59-
tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv"));
60-
} catch (final ParseException ex) {
61-
throw new ProtocolException("Invalid protocol: " + token);
69+
parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> {
70+
while (!cursor.atEnd()) {
71+
TOKENIZER.skipWhiteSpace(buffer, cursor);
72+
if (cursor.atEnd()) {
73+
break;
74+
}
75+
final int tokenStart = cursor.getPos();
76+
TOKENIZER.parseToken(buffer, cursor, UPGRADE_TOKEN_DELIMITER);
77+
final int tokenEnd = cursor.getPos();
78+
if (tokenStart < tokenEnd) {
79+
final ProtocolVersion version = parseProtocolToken(buffer, tokenStart, tokenEnd);
80+
if (version != null && "TLS".equalsIgnoreCase(version.getProtocol())) {
81+
tlsUpgrade.set(version);
82+
}
6283
}
63-
} else if (token.equals("HTTP/1.1")) {
64-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
84+
if (!cursor.atEnd()) {
85+
cursor.updatePos(cursor.getPos() + 1);
86+
}
87+
}
88+
});
89+
90+
final ProtocolVersion result = tlsUpgrade.get();
91+
if (result != null) {
92+
return result;
93+
} else {
94+
throw new ProtocolException("Invalid protocol switch response: no TLS version found");
95+
}
96+
}
97+
98+
private ProtocolVersion parseProtocolToken(final CharSequence buffer, final int start, final int end)
99+
throws ProtocolException {
100+
if (start >= end) {
101+
return null;
102+
}
103+
104+
if (end - start == 3) {
105+
final char c0 = buffer.charAt(start);
106+
final char c1 = buffer.charAt(start + 1);
107+
final char c2 = buffer.charAt(start + 2);
108+
if ((c0 == 'T' || c0 == 't') &&
109+
(c1 == 'L' || c1 == 'l') &&
110+
(c2 == 'S' || c2 == 's')) {
111+
return TLS.V_1_2.getVersion();
112+
}
113+
}
114+
115+
try {
116+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(start, end);
117+
final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null);
118+
119+
if ("TLS".equalsIgnoreCase(version.getProtocol())) {
120+
return version;
121+
} else if (version.equals(HttpVersion.HTTP_1_1)) {
122+
return null;
65123
} else {
66-
throw new ProtocolException("Unsupported protocol: " + token);
124+
throw new ProtocolException("Unsupported protocol or HTTP version: " + buffer.subSequence(start, end));
67125
}
126+
} catch (final ParseException ex) {
127+
throw new ProtocolException("Invalid protocol: " + buffer.subSequence(start, end), ex);
68128
}
69-
if (tlsUpgrade == null) {
70-
throw new ProtocolException("Invalid protocol switch response");
129+
}
130+
131+
private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer)
132+
throws ProtocolException {
133+
Args.notNull(message, "Message headers");
134+
Args.notBlank(name, "Header name");
135+
final Iterator<Header> it = message.headerIterator(name);
136+
while (it.hasNext()) {
137+
parseHeader(it.next(), consumer);
71138
}
72-
return tlsUpgrade;
73139
}
74140

75-
}
141+
private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException {
142+
Args.notNull(header, "Header");
143+
if (header instanceof FormattedHeader) {
144+
final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer();
145+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, buf.length());
146+
cursor.updatePos(((FormattedHeader) header).getValuePos());
147+
consumer.accept(buf, cursor);
148+
} else {
149+
final String value = header.getValue();
150+
if (value == null) {
151+
return;
152+
}
153+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, value.length());
154+
consumer.accept(value, cursor);
155+
}
156+
}
157+
}

httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030
import org.apache.hc.core5.http.HttpResponse;
3131
import org.apache.hc.core5.http.HttpStatus;
3232
import org.apache.hc.core5.http.ProtocolException;
33+
import org.apache.hc.core5.http.ProtocolVersion;
3334
import org.apache.hc.core5.http.message.BasicHttpResponse;
3435
import org.apache.hc.core5.http.ssl.TLS;
3536
import org.junit.jupiter.api.Assertions;
3637
import org.junit.jupiter.api.BeforeEach;
3738
import org.junit.jupiter.api.Test;
3839

3940
/**
40-
* Simple tests for {@link DefaultAuthenticationStrategy}.
41+
* Simple tests for {@link ProtocolSwitchStrategy}.
4142
*/
4243
class TestProtocolSwitchStrategy {
4344

@@ -95,4 +96,120 @@ void testSwitchInvalid() {
9596
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
9697
}
9798

99+
@Test
100+
void testNullToken() throws ProtocolException {
101+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
102+
response.addHeader(HttpHeaders.UPGRADE, "TLS,");
103+
response.addHeader(HttpHeaders.UPGRADE, null);
104+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
105+
}
106+
107+
@Test
108+
void testWhitespaceOnlyToken() throws ProtocolException {
109+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
110+
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
111+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
112+
}
113+
114+
@Test
115+
void testUnsupportedTlsVersion() throws Exception {
116+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
117+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
118+
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
119+
}
120+
121+
@Test
122+
void testUnsupportedTlsMajorVersion() throws Exception {
123+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
124+
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
125+
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
126+
}
127+
128+
@Test
129+
void testUnsupportedHttpVersion() {
130+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
131+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
132+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
133+
"Unsupported HTTP version: HTTP/2.0");
134+
}
135+
136+
@Test
137+
void testInvalidTlsFormat() {
138+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
139+
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
140+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
141+
"Invalid protocol: TLS/abc");
142+
}
143+
144+
@Test
145+
void testHttp11Only() {
146+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
147+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
148+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
149+
"Invalid protocol switch response: no TLS version found");
150+
}
151+
152+
@Test
153+
void testSwitchToTlsValid_TLS_1_2() throws Exception {
154+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
155+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
156+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
157+
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
158+
}
159+
160+
@Test
161+
void testSwitchToTlsValid_TLS_1_0() throws Exception {
162+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
163+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
164+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
165+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
166+
}
167+
168+
@Test
169+
void testSwitchToTlsValid_TLS_1_1() throws Exception {
170+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
171+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
172+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
173+
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
174+
}
175+
176+
@Test
177+
void testInvalidTlsFormat_NoSlash() {
178+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
179+
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
180+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
181+
"Invalid protocol: TLSv1");
182+
}
183+
184+
@Test
185+
void testSwitchToTlsValid_TLS_1() throws Exception {
186+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
187+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
188+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
189+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
190+
}
191+
192+
@Test
193+
void testInvalidTlsFormat_MissingMajor() {
194+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
195+
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
196+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
197+
"Invalid protocol: TLS/.1");
198+
}
199+
200+
@Test
201+
void testMultipleHttp11Tokens() {
202+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
203+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
204+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
205+
"Invalid protocol switch response: no TLS version found");
206+
}
207+
208+
@Test
209+
void testMixedInvalidAndValidTokens() {
210+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
211+
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
212+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
213+
"Invalid protocol: Crap");
214+
}
98215
}

0 commit comments

Comments
 (0)