Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cms/grading/ParameterTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

from abc import ABCMeta, abstractmethod

from jinja2 import Markup, Template
from jinja2 import Template
from markupsafe import Markup
import typing

if typing.TYPE_CHECKING:
Expand Down
3 changes: 1 addition & 2 deletions cms/io/web_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def __init__(
self._service = service
self._auth = auth
self._url_map = Map([Rule("/<service>/<int:shard>/<method>",
methods=["POST"], endpoint="rpc")],
encoding_errors="strict")
methods=["POST"], endpoint="rpc")])

def __call__(self, environ, start_response):
"""Execute this instance as a WSGI application.
Expand Down
2 changes: 1 addition & 1 deletion cms/io/web_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import tornado.wsgi
from gevent.pywsgi import WSGIServer
from werkzeug.contrib.fixers import ProxyFix
from werkzeug.middleware.proxy_fix import ProxyFix
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.middleware.shared_data import SharedDataMiddleware

Expand Down
55 changes: 32 additions & 23 deletions cms/server/admin/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,17 @@

from collections.abc import Callable
import json
import math
import typing

from werkzeug.contrib.securecookie import SecureCookie
from tornado.web import create_signed_value, decode_signed_value
from werkzeug.local import Local, LocalManager
from werkzeug.wrappers import Request, Response

from cms import config
from cmscommon.binary import hex_to_bin
from cmscommon.datetime import make_timestamp


class UTF8JSON:
@staticmethod
def dumps(d: object) -> bytes:
return json.dumps(d).encode('utf-8')

@staticmethod
def loads(e: bytes) -> object:
return json.loads(e.decode('utf-8'))


class JSONSecureCookie(SecureCookie):
serialization_method = UTF8JSON


class AWSAuthMiddleware:
"""Handler for the low-level tasks of admin authentication.

Expand Down Expand Up @@ -70,7 +57,7 @@ def __init__(self, app: Callable):
self.wsgi_app = self._local_manager.make_middleware(self.wsgi_app)

self._request: Request = self._local("request")
self._cookie: JSONSecureCookie = self._local("cookie")
self._cookie: dict[str, typing.Any] = self._local("cookie")

@property
def admin_id(self) -> int | None:
Expand Down Expand Up @@ -128,9 +115,20 @@ def wsgi_app(self, environ: dict, start_response: Callable):

"""
self._local.request = Request(environ)
self._local.cookie = JSONSecureCookie.load_cookie(
self._request, AWSAuthMiddleware.COOKIE,
hex_to_bin(config.web_server.secret_key))
cookie_str = decode_signed_value(
bytes.fromhex(config.web_server.secret_key),
AWSAuthMiddleware.COOKIE,
self._request.cookies.get(AWSAuthMiddleware.COOKIE),
# We do our own expiry checking, so an upper bound is fine here
max_age_days=math.ceil(
config.admin_web_server.cookie_duration / 60 / 60 / 24
),
)
if cookie_str is not None:
self._local.cookie = json.loads(cookie_str.decode())
else:
self._local.cookie = {}

self._verify_cookie()

def my_start_response(status, headers, exc_info=None):
Expand All @@ -142,9 +140,20 @@ def my_start_response(status, headers, exc_info=None):

"""
response = Response(status=status, headers=headers)
self._cookie.save_cookie(
response, AWSAuthMiddleware.COOKIE, httponly=True,
max_age=config.admin_web_server.cookie_duration)
# json.dumps doesn't like LocalProxy objects, so we grab the actual
# underlying value here with _get_current_object
cookie_str = json.dumps(self._cookie._get_current_object())
cookie_signed = create_signed_value(
bytes.fromhex(config.web_server.secret_key),
AWSAuthMiddleware.COOKIE,
cookie_str,
).decode()
response.set_cookie(
AWSAuthMiddleware.COOKIE,
cookie_signed,
httponly=True,
max_age=config.admin_web_server.cookie_duration,
)
return start_response(
status, response.headers.to_wsgi_list(), exc_info)

Expand Down
4 changes: 2 additions & 2 deletions cms/server/contest/jinja2_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

"""

from jinja2 import contextfilter, PackageLoader
from jinja2 import pass_context, PackageLoader

from cms.server.jinja2_toolbox import GLOBAL_ENVIRONMENT
from .formatting import format_token_rules, get_score_class
Expand All @@ -38,7 +38,7 @@ def instrument_cms_toolbox(env):
env.filters["extract_token_params"] = extract_token_params


@contextfilter
@pass_context
def wrapped_format_token_rules(ctx, tokens, t_type=None):
translation = ctx["translation"]
return format_token_rules(tokens, t_type, translation=translation)
Expand Down
34 changes: 17 additions & 17 deletions cms/server/jinja2_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
"""

from datetime import datetime, timedelta, tzinfo
from jinja2 import Environment, StrictUndefined, contextfilter, \
contextfunction, environmentfunction
from jinja2 import Environment, StrictUndefined, pass_context, \
pass_environment
from jinja2.runtime import Context
import markdown_it
import markupsafe
Expand All @@ -45,7 +45,7 @@
from cmscommon.mimetypes import get_type_for_file_name, get_icon_for_type


@contextfilter
@pass_context
def all_(ctx: Context, l: list, test: str | None = None, *args) -> bool:
"""Check if all elements of the given list pass the given test.

Expand All @@ -69,7 +69,7 @@ def all_(ctx: Context, l: list, test: str | None = None, *args) -> bool:
return True


@contextfilter
@pass_context
def any_(ctx: Context, l: list, test: str | None = None, *args) -> bool:
"""Check if any element of the given list passes the given test.

Expand All @@ -93,7 +93,7 @@ def any_(ctx: Context, l: list, test: str | None = None, *args) -> bool:
return False


@contextfilter
@pass_context
def dictselect(
ctx: Context, d: dict, test: str | None = None, *args, by: str = "key"
) -> dict:
Expand Down Expand Up @@ -122,7 +122,7 @@ def dictselect(
if ctx.call(test, {"key": k, "value": v}[by], *args))


@contextfunction
@pass_context
def today(ctx: Context, dt: datetime) -> bool:
"""Returns whether the given datetime is today.

Expand Down Expand Up @@ -185,7 +185,7 @@ def instrument_generic_toolbox(env: Environment):
env.tests["today"] = today


@environmentfunction
@pass_environment
def safe_get_task_type(env: Environment, *, dataset: Dataset):
try:
return dataset.task_type_object
Expand All @@ -195,7 +195,7 @@ def safe_get_task_type(env: Environment, *, dataset: Dataset):
return env.undefined("TaskType not found: %s" % err)


@environmentfunction
@pass_environment
def safe_get_score_type(env: Environment, *, dataset: Dataset):
try:
return dataset.score_type_object
Expand All @@ -215,59 +215,59 @@ def instrument_cms_toolbox(env: Environment):
env.filters["to_language"] = get_language


@contextfilter
@pass_context
def format_datetime(ctx: Context, dt: datetime):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
timezone: tzinfo = ctx.get("timezone", local_tz)
return translation.format_datetime(dt, timezone)


@contextfilter
@pass_context
def format_time(ctx: Context, dt: datetime):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
timezone: tzinfo = ctx.get("timezone", local_tz)
return translation.format_time(dt, timezone)


@contextfilter
@pass_context
def format_datetime_smart(ctx: Context, dt: datetime):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
now: datetime = ctx.get("now", make_datetime())
timezone: tzinfo = ctx.get("timezone", local_tz)
return translation.format_datetime_smart(dt, now, timezone)


@contextfilter
@pass_context
def format_timedelta(ctx: Context, td: timedelta):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return translation.format_timedelta(td)


@contextfilter
@pass_context
def format_duration(ctx: Context, d: float, length: str = "short"):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return translation.format_duration(d, length)


@contextfilter
@pass_context
def format_size(ctx: Context, s: int):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return translation.format_size(s)


@contextfilter
@pass_context
def format_decimal(ctx: Context, n: int):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return translation.format_decimal(n)


@contextfilter
@pass_context
def format_locale(ctx: Context, n: str):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return translation.format_locale(n)


@contextfilter
@pass_context
def wrapped_format_status_text(ctx: Context, status_text: list[str]):
translation: Translation = ctx.get("translation", DEFAULT_TRANSLATION)
return format_status_text(status_text, translation=translation)
Expand Down
7 changes: 6 additions & 1 deletion cmscommon/eventsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,12 @@ def wsgi_app(self, environ, start_response):
# XMLHttpRequest it has been probably sent from a polyfill (not
# from the native browser implementation) which will be able to
# read the response body only when it has been fully received.
if environ["SERVER_PROTOCOL"] != "HTTP/1.1" or request.is_xhr:

# XXX: this used to also check request.is_xhr, which was removed in a
# newer werkzeug version. But all modern browsers support SSE natively
# so this check isn't necessary nowadays. (Well, the http/1.1 check
# probably isn't necessary either, to be honest...)
if environ["SERVER_PROTOCOL"] != "HTTP/1.1":
one_shot = True
else:
one_shot = False
Expand Down
Loading
Loading