Skip to content

Commit 384e17a

Browse files
Implement websockets models
1 parent 1c2d8bb commit 384e17a

File tree

9 files changed

+193
-0
lines changed

9 files changed

+193
-0
lines changed

python/ql/lib/semmle/python/Frameworks.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ private import semmle.python.frameworks.TRender
8989
private import semmle.python.frameworks.Twisted
9090
private import semmle.python.frameworks.Ujson
9191
private import semmle.python.frameworks.Urllib3
92+
private import semmle.python.frameworks.Websockets
9293
private import semmle.python.frameworks.Xmltodict
9394
private import semmle.python.frameworks.Yaml
9495
private import semmle.python.frameworks.Yarl
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/**
2+
* Provides definitions and modeling for the `websockets` PyPI package.
3+
*
4+
* See https://websockets.readthedocs.io/en/stable/
5+
*/
6+
7+
private import python
8+
private import semmle.python.dataflow.new.RemoteFlowSources
9+
private import semmle.python.Concepts
10+
private import semmle.python.ApiGraphs
11+
private import semmle.python.frameworks.internal.PoorMansFunctionResolution
12+
private import semmle.python.frameworks.internal.InstanceTaintStepsHelper
13+
14+
/**
15+
* Provides models for the `websockets` PyPI package.
16+
* See https://websockets.readthedocs.io/en/stable/
17+
*/
18+
module Websockets {
19+
private class HandlerArg extends DataFlow::Node {
20+
HandlerArg() {
21+
exists(DataFlow::CallCfgNode c |
22+
c =
23+
API::moduleImport("websockets")
24+
.getMember(["asyncio", "sync"])
25+
.getMember("server")
26+
.getMember(["serve", "unix_serve"])
27+
.getACall()
28+
|
29+
(this = c.getArg(0) or this = c.getArgByName("handler"))
30+
)
31+
}
32+
}
33+
34+
/** A websocket handler that is passed to `serve`. */
35+
// TODO: handlers defined via route maps, e.g. through `websockets.asyncio.router.route`, are more complex to handle.
36+
class WebSocketHandler extends Http::Server::RequestHandler::Range {
37+
WebSocketHandler() { poorMansFunctionTracker(this) = any(HandlerArg a) }
38+
39+
override Parameter getARoutedParameter() { result = this.getAnArg() }
40+
41+
override string getFramework() { result = "websockets" }
42+
}
43+
44+
module ServerConnection {
45+
/**
46+
* A source of instances of `websockets.asyncio.ServerConnection` and `websockets.threading.ServerConnection`, extend this class to model new instances.
47+
*
48+
* This can include instantiations of the class, return values from function
49+
* calls, or a special parameter that will be set when functions are called by an external
50+
* library.
51+
*
52+
* Use the predicate `WebSocket::instance()` to get references to instances of `websockets.asyncio.ServerConnection` and `websockets.threading.ServerConnection`.
53+
*/
54+
abstract class InstanceSource extends DataFlow::LocalSourceNode { }
55+
56+
/** Gets a reference to an instance of `websockets.asyncio.ServerConnection` or `websockets.threading.ServerConnection`. */
57+
private DataFlow::TypeTrackingNode instance(DataFlow::TypeTracker t) {
58+
t.start() and
59+
result instanceof InstanceSource
60+
or
61+
exists(DataFlow::TypeTracker t2 | result = instance(t2).track(t2, t))
62+
}
63+
64+
/** Gets a reference to an instance of `websockets.asyncio.ServerConnection` or `websockets.threading.ServerConnection`. */
65+
DataFlow::Node instance() { instance(DataFlow::TypeTracker::end()).flowsTo(result) }
66+
67+
private class HandlerParam extends DataFlow::Node, InstanceSource {
68+
HandlerParam() { exists(WebSocketHandler h | this = DataFlow::parameterNode(h.getArg(0))) }
69+
}
70+
71+
private class InstanceTaintSteps extends InstanceTaintStepsHelper {
72+
InstanceTaintSteps() { this = "websockets.asyncio.ServerConnection" }
73+
74+
override DataFlow::Node getInstance() { result = instance() }
75+
76+
override string getAttributeName() { none() }
77+
78+
override string getAsyncMethodName() { result = ["recv", "recv_streaming"] }
79+
80+
override string getMethodName() { result = ["recv", "recv_streaming"] }
81+
}
82+
}
83+
}

python/ql/test/library-tests/frameworks/websockets/ConceptsTest.expected

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import python
2+
import experimental.meta.ConceptsTest
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
argumentToEnsureNotTaintedNotMarkedAsSpurious
2+
untaintedArgumentToEnsureTaintedNotMarkedAsMissing
3+
testFailures
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import experimental.meta.InlineTaintTest
2+
import MakeInlineTaintTest<TestTaintTrackingConfig>
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import websockets.sync.server
2+
import websockets.sync.router
3+
from werkzeug.routing import Map, Rule
4+
5+
def arg_handler(websocket): # $ requestHandler routedParameter=websocket
6+
websocket.send("arg" + websocket.recv())
7+
8+
s1 = websockets.sync.server.serve(arg_handler, "localhost", 8000)
9+
10+
def kw_handler(websocket): # $ requestHandler routedParameter=websocket
11+
websocket.send("kw" + websocket.recv())
12+
13+
s2 = websockets.sync.server.serve(handler=kw_handler, host="localhost", port=8001)
14+
15+
def route_handler(websocket, x): # $ MISSING: requestHandler routedParameter=websocket routedParameter=x
16+
websocket.send(f"route {x} {websocket.recv()}")
17+
18+
s3 = websockets.sync.router.route(Map([
19+
Rule("/<string:x>", endpoint=route_handler)
20+
]), "localhost", 8002)
21+
22+
def unix_handler(websocket): # $ requestHandler routedParameter=websocket
23+
websocket.send("unix" + websocket.recv())
24+
25+
s4 = websockets.sync.server.unix_serve(unix_handler, path="/tmp/ws.sock")
26+
27+
def unix_route_handler(websocket, x): # $ MISSING: requestHandler routedParameter=websocket routedParameter=x
28+
websocket.send(f"unix route {x} {websocket.recv()}")
29+
30+
s5 = websockets.sync.router.unix_route(Map([
31+
Rule("/<string:x>", endpoint=unix_route_handler)
32+
]), path="/tmp/ws2.sock")
33+
34+
if __name__ == "__main__":
35+
import sys
36+
server = s1
37+
if len(sys.argv) > 1:
38+
if sys.argv[1] == "kw":
39+
server = s2
40+
elif sys.argv[1] == "route":
41+
server = s3
42+
elif sys.argv[1] == "unix":
43+
server = s4
44+
elif sys.argv[1] == "unix_route":
45+
server = s5
46+
server.serve_forever()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import websockets.asyncio.server
2+
import asyncio
3+
4+
def ensure_tainted(*args):
5+
print("tainted", args)
6+
7+
def ensure_not_tainted(*args):
8+
print("not tainted", args)
9+
10+
async def handler(websocket): # $ requestHandler routedParameter=websocket
11+
ensure_tainted(
12+
websocket, # $ tainted
13+
await websocket.recv() # $ tainted
14+
)
15+
16+
async for msg in websocket:
17+
ensure_tainted(msg) # $ tainted
18+
await websocket.send(msg)
19+
20+
async for msg in websocket.recv_streaming():
21+
ensure_tainted(msg) # $ tainted
22+
await websocket.send(msg)
23+
24+
25+
async def main():
26+
server = await websockets.asyncio.server.serve(handler, "localhost", 8000)
27+
await server.serve_forever()
28+
29+
if __name__ == "__main__":
30+
asyncio.run(main())
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import websockets.sync.server
2+
3+
def ensure_tainted(*args):
4+
print("tainted", args)
5+
6+
def ensure_not_tainted(*args):
7+
print("not tainted", args)
8+
9+
def handler(websocket): # $ requestHandler routedParameter=websocket
10+
ensure_tainted(
11+
websocket, # $ tainted
12+
websocket.recv() # $ tainted
13+
)
14+
15+
for msg in websocket:
16+
ensure_tainted(msg) # $ tainted
17+
websocket.send(msg)
18+
19+
for msg in websocket.recv_streaming():
20+
ensure_tainted(msg) # $ tainted
21+
websocket.send(msg)
22+
23+
24+
if __name__ == "__main__":
25+
server = websockets.sync.server.serve(handler, "localhost", 8000)
26+
server.serve_forever()

0 commit comments

Comments
 (0)