Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 83 additions & 18 deletions gigl/common/logger.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,88 @@
import json
import logging
import os
import pathlib
from datetime import datetime
import sys
from datetime import datetime, timezone
from typing import Any, MutableMapping, Optional

from google.cloud import logging as google_cloud_logging

_BASE_LOG_FILE_PATH = "/tmp/research/gbml/logs"

_PYTHON_LEVEL_TO_GCP_SEVERITY: dict[str, str] = {
"DEBUG": "DEBUG",
"INFO": "INFO",
"WARNING": "WARNING",
"ERROR": "ERROR",
"CRITICAL": "CRITICAL",
}

class Logger(logging.LoggerAdapter):
# Key used by Logger.process() to pass user-supplied extras to the formatter
# without mixing them into the LogRecord's built-in attributes.
_GCP_LABELS_RECORD_ATTR: str = "_gcp_labels"


class _GcpJsonFormatter(logging.Formatter):
"""A ``logging.Formatter`` that outputs one JSON object per line with
`GCP-recognized structured logging fields
<https://cloud.google.com/logging/docs/structured-logging>`_.

Fields emitted:

- ``severity`` -- mapped from the Python log level.
- ``message`` -- the formatted log message (with traceback appended when present).
- ``time`` -- ISO 8601 UTC timestamp.
- ``logging.googleapis.com/sourceLocation`` -- ``{file, line, function}``.
- ``logging.googleapis.com/labels`` -- any extra fields supplied via the
``extra`` dict on the ``Logger`` adapter. Omitted when there are no extras.
"""
GiGL's custom logger class used for local and cloud logging (VertexAI, Dataflow, etc.)

def format(self, record: logging.LogRecord) -> str:
"""Format *record* as a single-line JSON string.

Args:
record: The ``LogRecord`` to format.

Returns:
A JSON string (no trailing newline) suitable for writing to
``sys.stderr`` on GCP-managed environments.
"""
message = record.getMessage()

if record.exc_info and not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
message = f"{message}\n{record.exc_text}"
if record.stack_info:
message = f"{message}\n{record.stack_info}"

payload: dict[str, object] = {
"severity": _PYTHON_LEVEL_TO_GCP_SEVERITY.get(
record.levelname, record.levelname
),
"message": message,
"time": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
"logging.googleapis.com/sourceLocation": {
"file": record.pathname,
"line": record.lineno,
"function": record.funcName,
},
}

labels: dict[str, object] = getattr(record, _GCP_LABELS_RECORD_ATTR, {})
if labels:
payload["logging.googleapis.com/labels"] = labels

return json.dumps(payload, ensure_ascii=False, default=str)


class Logger(logging.LoggerAdapter):
"""GiGL's custom logger class used for local and cloud logging (VertexAI, Dataflow, etc.).

Args:
logger (Optional[logging.Logger]): A custom logger to use. If not provided, the default logger will be created.
name (Optional[str]): The name to be used for the logger. By default uses "root".
log_to_file (bool): If True, logs will be written to a file. If False, logs will be written to the console.
extra (Optional[dict[str, Any]]): Extra information to be added to the log message.
logger: A custom logger to use. If not provided, the default logger will be created.
name: The name to be used for the logger. By default uses "root".
log_to_file: If True, logs will be written to a file. If False, logs will be written to the console.
extra: Extra information to be added to the log message.
"""

def __init__(
Expand All @@ -37,12 +103,11 @@ def _setup_logger(
) -> None:
handler: logging.Handler
if not logger.handlers:
if os.getenv("GAE_APPLICATION") or os.environ.get(
"KUBERNETES_SERVICE_HOST"
):
# Google Cloud Logging
client = google_cloud_logging.Client()
client.setup_logging(log_level=logging.INFO)
# Check if running on GCP.
if os.getenv("GAE_APPLICATION") or os.getenv("KUBERNETES_SERVICE_HOST"):
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(_GcpJsonFormatter())
logger.addHandler(handler)
else:
# Logging locally. Set up logging to console or file
if log_to_file:
Expand All @@ -64,10 +129,10 @@ def _setup_logger(
logger.setLevel(logging.INFO)

def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Any:
merged: dict[str, Any] = dict(self.extra if self.extra else {})
if "extra" in kwargs:
kwargs["extra"].update(self.extra)
else:
kwargs["extra"] = self.extra
merged.update(kwargs["extra"])
kwargs["extra"] = {**merged, _GCP_LABELS_RECORD_ATTR: merged}
return msg, kwargs

def __getattr__(self, name: str):
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/common/logger_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import json
import logging
import sys

from gigl.common.logger import _GCP_LABELS_RECORD_ATTR, _GcpJsonFormatter
from tests.test_assets.test_case import TestCase


class GcpJsonFormatterTest(TestCase):
    """Unit tests for the structured-JSON output of ``_GcpJsonFormatter``."""

    def setUp(self) -> None:
        self.formatter = _GcpJsonFormatter()

    def _make_record(
        self,
        message: str = "test message",
        level: int = logging.INFO,
        exc_info: object = None,
    ) -> logging.LogRecord:
        """Create a ``LogRecord`` with deterministic source location."""
        rec = logging.LogRecord(
            name="test",
            level=level,
            pathname="test_file.py",
            lineno=42,
            msg=message,
            args=None,
            exc_info=exc_info,  # type: ignore[arg-type]
        )
        rec.funcName = "test_func"
        return rec

    def _render(self, record: logging.LogRecord) -> dict:
        """Format *record* and decode the resulting JSON line."""
        return json.loads(self.formatter.format(record))

    def test_basic_info_message_produces_valid_json(self) -> None:
        parsed = self._render(self._make_record())

        self.assertEqual(parsed["severity"], "INFO")
        self.assertEqual(parsed["message"], "test message")
        self.assertIn("time", parsed)
        self.assertIn("logging.googleapis.com/sourceLocation", parsed)

    def test_severity_mapping_for_all_levels(self) -> None:
        expectations = (
            (logging.DEBUG, "DEBUG"),
            (logging.INFO, "INFO"),
            (logging.WARNING, "WARNING"),
            (logging.ERROR, "ERROR"),
            (logging.CRITICAL, "CRITICAL"),
        )
        for py_level, expected in expectations:
            with self.subTest(level=py_level):
                parsed = self._render(self._make_record(level=py_level))
                self.assertEqual(parsed["severity"], expected)

    def test_output_is_single_line(self) -> None:
        rendered = self.formatter.format(self._make_record())
        self.assertEqual(rendered.count("\n"), 0)

    def test_time_field_is_iso_8601(self) -> None:
        timestamp = self._render(self._make_record())["time"]
        # ISO 8601 with timezone: contains 'T' separator and '+' offset
        self.assertIn("T", timestamp)
        self.assertIn("+", timestamp)

    def test_extra_fields_appear_under_labels(self) -> None:
        rec = self._make_record()
        setattr(rec, _GCP_LABELS_RECORD_ATTR, {"custom_key": "custom_value"})

        labels = self._render(rec)["logging.googleapis.com/labels"]
        self.assertEqual(labels["custom_key"], "custom_value")

    def test_no_labels_key_when_no_extras(self) -> None:
        parsed = self._render(self._make_record())
        self.assertNotIn("logging.googleapis.com/labels", parsed)

    def test_exception_traceback_in_message(self) -> None:
        try:
            raise ValueError("boom")
        except ValueError:
            captured = sys.exc_info()

        parsed = self._render(self._make_record(exc_info=captured))

        self.assertIn("ValueError: boom", parsed["message"])
        self.assertIn("Traceback", parsed["message"])

    def test_source_location_fields(self) -> None:
        location = self._render(self._make_record())[
            "logging.googleapis.com/sourceLocation"
        ]

        self.assertEqual(location["file"], "test_file.py")
        self.assertEqual(location["line"], 42)
        self.assertEqual(location["function"], "test_func")