Skip to content

Commit 429759f

Browse files
committed
Simplify ProtocolSwitchStrategy by Leveraging ProtocolVersionParser
Unify HTTP and TLS token parsing in the Upgrade header by replacing custom version parsing with ProtocolVersionParser. This change removes redundant code and ensures that only supported protocols (HTTP/ and TLS tokens) are accepted, while all other upgrade protocols are rejected as unsupported.
1 parent 5e62dac commit 429759f

2 files changed

Lines changed: 161 additions & 16 deletions

File tree

httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@
3131
import org.apache.hc.core5.annotation.Internal;
3232
import org.apache.hc.core5.http.HttpHeaders;
3333
import org.apache.hc.core5.http.HttpMessage;
34+
import org.apache.hc.core5.http.HttpVersion;
3435
import org.apache.hc.core5.http.ParseException;
3536
import org.apache.hc.core5.http.ProtocolException;
3637
import org.apache.hc.core5.http.ProtocolVersion;
38+
import org.apache.hc.core5.http.ProtocolVersionParser;
3739
import org.apache.hc.core5.http.message.MessageSupport;
3840
import org.apache.hc.core5.http.ssl.TLS;
41+
import org.apache.hc.core5.util.CharArrayBuffer;
42+
import org.apache.hc.core5.util.Tokenizer;
3943

4044
/**
4145
* Protocol switch handler.
@@ -45,31 +49,55 @@
4549
@Internal
4650
public final class ProtocolSwitchStrategy {
4751

48-
enum ProtocolSwitch { FAILURE, TLS }
52+
private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE;
4953

5054
public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException {
5155
final Iterator<String> it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE);
5256

5357
ProtocolVersion tlsUpgrade = null;
58+
5459
while (it.hasNext()) {
5560
final String token = it.next();
56-
if (token.startsWith("TLS")) {
57-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
58-
try {
59-
tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv"));
60-
} catch (final ParseException ex) {
61-
throw new ProtocolException("Invalid protocol: " + token);
61+
final ProtocolVersion version = parseProtocolToken(token);
62+
if (version != null) {
63+
if ("TLS".equalsIgnoreCase(version.getProtocol())) {
64+
tlsUpgrade = version;
6265
}
63-
} else if (token.equals("HTTP/1.1")) {
64-
// TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method
65-
} else {
66-
throw new ProtocolException("Unsupported protocol: " + token);
6766
}
6867
}
69-
if (tlsUpgrade == null) {
70-
throw new ProtocolException("Invalid protocol switch response");
68+
69+
if (tlsUpgrade != null) {
70+
return tlsUpgrade;
71+
} else {
72+
throw new ProtocolException("Invalid protocol switch response: no TLS version found");
7173
}
72-
return tlsUpgrade;
7374
}
7475

75-
}
76+
private ProtocolVersion parseProtocolToken(final String token) throws ProtocolException {
77+
if (token == null || token.trim().isEmpty()) {
78+
return null;
79+
}
80+
81+
try {
82+
if ("TLS".equalsIgnoreCase(token)) {
83+
return TLS.V_1_2.getVersion();
84+
}
85+
86+
final CharArrayBuffer buffer = new CharArrayBuffer(token.length());
87+
buffer.append(token);
88+
final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, token.length());
89+
90+
final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null);
91+
92+
if ("TLS".equalsIgnoreCase(version.getProtocol())) {
93+
return version;
94+
} else if (version.equals(HttpVersion.HTTP_1_1)) {
95+
return null;
96+
} else {
97+
throw new ProtocolException("Unsupported protocol or HTTP version: " + token);
98+
}
99+
} catch (final ParseException ex) {
100+
throw new ProtocolException("Invalid protocol: " + token, ex);
101+
}
102+
}
103+
}

httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030
import org.apache.hc.core5.http.HttpResponse;
3131
import org.apache.hc.core5.http.HttpStatus;
3232
import org.apache.hc.core5.http.ProtocolException;
33+
import org.apache.hc.core5.http.ProtocolVersion;
3334
import org.apache.hc.core5.http.message.BasicHttpResponse;
3435
import org.apache.hc.core5.http.ssl.TLS;
3536
import org.junit.jupiter.api.Assertions;
3637
import org.junit.jupiter.api.BeforeEach;
3738
import org.junit.jupiter.api.Test;
3839

3940
/**
40-
* Simple tests for {@link DefaultAuthenticationStrategy}.
41+
* Simple tests for {@link ProtocolSwitchStrategy}.
4142
*/
4243
class TestProtocolSwitchStrategy {
4344

@@ -95,4 +96,120 @@ void testSwitchInvalid() {
9596
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3));
9697
}
9798

99+
@Test
100+
void testNullToken() throws ProtocolException {
101+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
102+
response.addHeader(HttpHeaders.UPGRADE, "TLS,");
103+
response.addHeader(HttpHeaders.UPGRADE, null);
104+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
105+
}
106+
107+
@Test
108+
void testWhitespaceOnlyToken() throws ProtocolException {
109+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
110+
response.addHeader(HttpHeaders.UPGRADE, " , TLS");
111+
Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response));
112+
}
113+
114+
@Test
115+
void testUnsupportedTlsVersion() throws Exception {
116+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
117+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4");
118+
Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response));
119+
}
120+
121+
@Test
122+
void testUnsupportedTlsMajorVersion() throws Exception {
123+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
124+
response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0");
125+
Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response));
126+
}
127+
128+
@Test
129+
void testUnsupportedHttpVersion() {
130+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
131+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0");
132+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
133+
"Unsupported HTTP version: HTTP/2.0");
134+
}
135+
136+
@Test
137+
void testInvalidTlsFormat() {
138+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
139+
response.addHeader(HttpHeaders.UPGRADE, "TLS/abc");
140+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
141+
"Invalid protocol: TLS/abc");
142+
}
143+
144+
@Test
145+
void testHttp11Only() {
146+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
147+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1");
148+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
149+
"Invalid protocol switch response: no TLS version found");
150+
}
151+
152+
@Test
153+
void testSwitchToTlsValid_TLS_1_2() throws Exception {
154+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
155+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2");
156+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
157+
Assertions.assertEquals(TLS.V_1_2.getVersion(), result);
158+
}
159+
160+
@Test
161+
void testSwitchToTlsValid_TLS_1_0() throws Exception {
162+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
163+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0");
164+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
165+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
166+
}
167+
168+
@Test
169+
void testSwitchToTlsValid_TLS_1_1() throws Exception {
170+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
171+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1");
172+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
173+
Assertions.assertEquals(TLS.V_1_1.getVersion(), result);
174+
}
175+
176+
@Test
177+
void testInvalidTlsFormat_NoSlash() {
178+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
179+
response.addHeader(HttpHeaders.UPGRADE, "TLSv1");
180+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
181+
"Invalid protocol: TLSv1");
182+
}
183+
184+
@Test
185+
void testSwitchToTlsValid_TLS_1() throws Exception {
186+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
187+
response.addHeader(HttpHeaders.UPGRADE, "TLS/1");
188+
final ProtocolVersion result = switchStrategy.switchProtocol(response);
189+
Assertions.assertEquals(TLS.V_1_0.getVersion(), result);
190+
}
191+
192+
@Test
193+
void testInvalidTlsFormat_MissingMajor() {
194+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
195+
response.addHeader(HttpHeaders.UPGRADE, "TLS/.1");
196+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
197+
"Invalid protocol: TLS/.1");
198+
}
199+
200+
@Test
201+
void testMultipleHttp11Tokens() {
202+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
203+
response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1");
204+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
205+
"Invalid protocol switch response: no TLS version found");
206+
}
207+
208+
@Test
209+
void testMixedInvalidAndValidTokens() {
210+
final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS);
211+
response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid");
212+
Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response),
213+
"Invalid protocol: Crap");
214+
}
98215
}

0 commit comments

Comments
 (0)