@@ -117,13 +117,22 @@ class HTTPServer(socketserver.TCPServer):
117117 allow_reuse_address = True # Seems to make sense in testing environment
118118 allow_reuse_port = True
119119
120+ def __init__ (self , * args , response_headers = None , ** kwargs ):
121+ self .response_headers = response_headers if response_headers is not None else {}
122+ super ().__init__ (* args , ** kwargs )
123+
120124 def server_bind (self ):
121125 """Override server_bind to store the server name."""
122126 socketserver .TCPServer .server_bind (self )
123127 host , port = self .server_address [:2 ]
124128 self .server_name = socket .getfqdn (host )
125129 self .server_port = port
126130
131+ def finish_request (self , request , client_address ):
132+ """Finish one request by instantiating RequestHandlerClass."""
133+ args = (request , client_address , self )
134+ kwargs = dict (response_headers = self .response_headers ) if self .response_headers else dict ()
135+ self .RequestHandlerClass (* args , ** kwargs )
127136
128137class ThreadingHTTPServer (socketserver .ThreadingMixIn , HTTPServer ):
129138 daemon_threads = True
@@ -132,7 +141,7 @@ class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
132141class HTTPSServer (HTTPServer ):
133142 def __init__ (self , server_address , RequestHandlerClass ,
134143 bind_and_activate = True , * , certfile , keyfile = None ,
135- password = None , alpn_protocols = None ):
144+ password = None , alpn_protocols = None , response_headers = None ):
136145 try :
137146 import ssl
138147 except ImportError :
@@ -150,7 +159,8 @@ def __init__(self, server_address, RequestHandlerClass,
150159
151160 super ().__init__ (server_address ,
152161 RequestHandlerClass ,
153- bind_and_activate )
162+ bind_and_activate ,
163+ response_headers = response_headers )
154164
155165 def server_activate (self ):
156166 """Wrap the socket in SSLSocket."""
@@ -692,10 +702,11 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
692702 '.xz' : 'application/x-xz' ,
693703 }
694704
695- def __init__ (self , * args , directory = None , ** kwargs ):
705+ def __init__ (self , * args , directory = None , response_headers = None , ** kwargs ):
696706 if directory is None :
697707 directory = os .getcwd ()
698708 self .directory = os .fspath (directory )
709+ self .response_headers = response_headers or {}
699710 super ().__init__ (* args , ** kwargs )
700711
701712 def do_GET (self ):
@@ -736,6 +747,10 @@ def send_head(self):
736747 new_url = urllib .parse .urlunsplit (new_parts )
737748 self .send_header ("Location" , new_url )
738749 self .send_header ("Content-Length" , "0" )
750+ # User specified response_headers
751+ if self .response_headers is not None :
752+ for header , value in self .response_headers .items ():
753+ self .send_header (header , value )
739754 self .end_headers ()
740755 return None
741756 for index in self .index_pages :
@@ -795,6 +810,9 @@ def send_head(self):
795810 self .send_header ("Content-Length" , str (fs [6 ]))
796811 self .send_header ("Last-Modified" ,
797812 self .date_time_string (fs .st_mtime ))
813+ if self .response_headers is not None :
814+ for header , value in self .response_headers .items ():
815+ self .send_header (header , value )
798816 self .end_headers ()
799817 return f
800818 except :
@@ -970,7 +988,7 @@ def _get_best_family(*address):
970988def test (HandlerClass = BaseHTTPRequestHandler ,
971989 ServerClass = ThreadingHTTPServer ,
972990 protocol = "HTTP/1.0" , port = 8000 , bind = None ,
973- tls_cert = None , tls_key = None , tls_password = None ):
991+ tls_cert = None , tls_key = None , tls_password = None , response_headers = None ):
974992 """Test the HTTP request handler class.
975993
976994 This runs an HTTP server on port 8000 (or the port argument).
@@ -981,9 +999,10 @@ def test(HandlerClass=BaseHTTPRequestHandler,
981999
9821000 if tls_cert :
9831001 server = ServerClass (addr , HandlerClass , certfile = tls_cert ,
984- keyfile = tls_key , password = tls_password )
1002+ keyfile = tls_key , password = tls_password ,
1003+ response_headers = response_headers )
9851004 else :
986- server = ServerClass (addr , HandlerClass )
1005+ server = ServerClass (addr , HandlerClass , response_headers = response_headers )
9871006
9881007 with server as httpd :
9891008 host , port = httpd .socket .getsockname ()[:2 ]
@@ -1024,6 +1043,8 @@ def _main(args=None):
10241043 parser .add_argument ('port' , default = 8000 , type = int , nargs = '?' ,
10251044 help = 'bind to this port '
10261045 '(default: %(default)s)' )
1046+ parser .add_argument ('--cors' , action = 'store_true' ,
1047+ help = 'Enable Access-Control-Allow-Origin: * header' )
10271048 args = parser .parse_args (args )
10281049
10291050 if not args .tls_cert and args .tls_key :
@@ -1051,15 +1072,19 @@ def server_bind(self):
10511072 return super ().server_bind ()
10521073
10531074 def finish_request (self , request , client_address ):
1054- self .RequestHandlerClass (request , client_address , self ,
1055- directory = args .directory )
1075+ handler_args = (request , client_address , self )
1076+ handler_kwargs = dict (directory = args .directory )
1077+ if self .response_headers :
1078+ handler_kwargs ['response_headers' ] = self .response_headers
1079+ self .RequestHandlerClass (* handler_args , ** handler_kwargs )
10561080
10571081 class HTTPDualStackServer (DualStackServerMixin , ThreadingHTTPServer ):
10581082 pass
10591083 class HTTPSDualStackServer (DualStackServerMixin , ThreadingHTTPSServer ):
10601084 pass
10611085
10621086 ServerClass = HTTPSDualStackServer if args .tls_cert else HTTPDualStackServer
1087+ response_headers = {'Access-Control-Allow-Origin' : '*' } if args .cors else None
10631088
10641089 test (
10651090 HandlerClass = SimpleHTTPRequestHandler ,
@@ -1070,6 +1095,7 @@ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
10701095 tls_cert = args .tls_cert ,
10711096 tls_key = args .tls_key ,
10721097 tls_password = tls_key_password ,
1098+ response_headers = response_headers
10731099 )
10741100
10751101
0 commit comments