Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions graphgen/bases/base_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

import numpy as np
if TYPE_CHECKING:
import numpy as np


class BaseFilter(ABC):
Expand All @@ -15,7 +16,7 @@ def filter(self, data: Any) -> bool:

class BaseValueFilter(BaseFilter, ABC):
@abstractmethod
def filter(self, data: Union[int, float, np.number]) -> bool:
def filter(self, data: Union[int, float, "np.number"]) -> bool:
"""
Filter the numeric value and return True if it passes the filter, False otherwise.
"""
Expand Down
24 changes: 17 additions & 7 deletions graphgen/bases/base_operator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import inspect
import os
from abc import ABC, abstractmethod
from typing import Iterable, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Tuple, Union

import numpy as np
import pandas as pd
import ray
if TYPE_CHECKING:
import numpy as np
import pandas as pd


def convert_to_serializable(obj):
import numpy as np

if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.generic):
Expand Down Expand Up @@ -40,6 +44,8 @@ def __init__(
)

try:
import ray

ctx = ray.get_runtime_context()
worker_id = ctx.get_actor_id() or ctx.get_worker_id()
worker_id_short = worker_id[-6:] if worker_id else "driver"
Expand All @@ -62,9 +68,11 @@ def __init__(
)

def __call__(
self, batch: pd.DataFrame
) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
self, batch: "pd.DataFrame"
) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]:
# lazy import to avoid circular import
import pandas as pd

from graphgen.utils import CURRENT_LOGGER_VAR

logger_token = CURRENT_LOGGER_VAR.set(self.logger)
Expand Down Expand Up @@ -106,14 +114,16 @@ def get_trace_id(self, content: dict) -> str:

return compute_dict_hash(content, prefix=f"{self.op_name}-")

def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]:
"""
Split the input batch into to_process & processed based on _meta data in KV_storage
:param batch
:return:
to_process: DataFrame of documents to be chunked
recovered: Result DataFrame of already chunked documents
"""
import pandas as pd

meta_forward = self.get_meta_forward()
meta_ids = set(meta_forward.keys())
mask = batch["_trace_id"].isin(meta_ids)
Expand Down
11 changes: 8 additions & 3 deletions graphgen/bases/base_reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Union

import pandas as pd
import requests
from ray.data import Dataset

if TYPE_CHECKING:
import pandas as pd
from ray.data import Dataset


class BaseReader(ABC):
Expand Down Expand Up @@ -51,6 +55,7 @@ def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
"""
Validate data format.
"""

if "type" not in batch.columns:
raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")

Expand Down
11 changes: 7 additions & 4 deletions graphgen/common/init_llm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from typing import Any, Dict, Optional

import ray
from typing import TYPE_CHECKING, Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer

if TYPE_CHECKING:
import ray


class LLMServiceActor:
"""
Expand Down Expand Up @@ -73,7 +74,7 @@ class LLMServiceProxy(BaseLLMWrapper):
A proxy class to interact with the LLMServiceActor for distributed LLM operations.
"""

def __init__(self, actor_handle: ray.actor.ActorHandle):
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
super().__init__()
self.actor_handle = actor_handle
self._create_local_tokenizer()
Expand Down Expand Up @@ -120,6 +121,8 @@ class LLMFactory:
def create_llm(
model_type: str, backend: str, config: Dict[str, Any]
) -> BaseLLMWrapper:
import ray

if not config:
raise ValueError(
f"No configuration provided for LLM {model_type} with backend {backend}."
Expand Down
22 changes: 21 additions & 1 deletion graphgen/common/init_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def ready(self) -> bool:


class RemoteKVStorageProxy(BaseKVStorage):
def __init__(self, actor_handle: ray.actor.ActorHandle):
def __init__(self, actor_handle: "ray.actor.ActorHandle"):
super().__init__()
self.actor = actor_handle

Expand Down Expand Up @@ -202,68 +202,87 @@ def get_all_node_degrees(self) -> Dict[str, int]:
return ray.get(self.actor.get_all_node_degrees.remote())

def get_node_count(self) -> int:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This blank line is unnecessary and harms readability by making the method less compact. This applies to many other simple proxy methods in this class as well (e.g., get_edge_count, has_node). For simple one-line proxy methods, it's best to keep them compact.

return ray.get(self.actor.get_node_count.remote())

def get_edge_count(self) -> int:

return ray.get(self.actor.get_edge_count.remote())

def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:

return ray.get(self.actor.get_connected_components.remote(undirected))

def has_node(self, node_id: str) -> bool:

return ray.get(self.actor.has_node.remote(node_id))

def has_edge(self, source_node_id: str, target_node_id: str):

return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))

def node_degree(self, node_id: str) -> int:

return ray.get(self.actor.node_degree.remote(node_id))

def edge_degree(self, src_id: str, tgt_id: str) -> int:

return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))

def get_node(self, node_id: str) -> Any:

return ray.get(self.actor.get_node.remote(node_id))

def update_node(self, node_id: str, node_data: dict[str, str]):

return ray.get(self.actor.update_node.remote(node_id, node_data))

def get_all_nodes(self) -> Any:

return ray.get(self.actor.get_all_nodes.remote())

def get_edge(self, source_node_id: str, target_node_id: str):

return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))

def update_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):

return ray.get(
self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
)

def get_all_edges(self) -> Any:

return ray.get(self.actor.get_all_edges.remote())

def get_node_edges(self, source_node_id: str) -> Any:

return ray.get(self.actor.get_node_edges.remote(source_node_id))

def upsert_node(self, node_id: str, node_data: dict[str, str]):

return ray.get(self.actor.upsert_node.remote(node_id, node_data))

def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):

return ray.get(
self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
)

def delete_node(self, node_id: str):

return ray.get(self.actor.delete_node.remote(node_id))

def get_neighbors(self, node_id: str) -> List[str]:

return ray.get(self.actor.get_neighbors.remote(node_id))

def reload(self):

return ray.get(self.actor.reload.remote())


Expand All @@ -274,6 +293,7 @@ class StorageFactory:

@staticmethod
def create_storage(backend: str, working_dir: str, namespace: str):

if backend in ["json_kv", "rocksdb"]:
actor_name = f"Actor_KV_{namespace}"
actor_class = KVStorageActor
Expand Down
Loading
Loading