diff --git a/graphgen/bases/base_filter.py b/graphgen/bases/base_filter.py index e46983e9..86a3edad 100644 --- a/graphgen/bases/base_filter.py +++ b/graphgen/bases/base_filter.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union -import numpy as np +if TYPE_CHECKING: + import numpy as np class BaseFilter(ABC): @@ -15,7 +16,7 @@ def filter(self, data: Any) -> bool: class BaseValueFilter(BaseFilter, ABC): @abstractmethod - def filter(self, data: Union[int, float, np.number]) -> bool: + def filter(self, data: Union[int, float, "np.number"]) -> bool: """ Filter the numeric value and return True if it passes the filter, False otherwise. """ diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py index 25a4bccf..30c71271 100644 --- a/graphgen/bases/base_operator.py +++ b/graphgen/bases/base_operator.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import inspect import os from abc import ABC, abstractmethod -from typing import Iterable, Tuple, Union +from typing import TYPE_CHECKING, Iterable, Tuple, Union -import numpy as np -import pandas as pd -import ray +if TYPE_CHECKING: + import numpy as np + import pandas as pd def convert_to_serializable(obj): + import numpy as np + if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, np.generic): @@ -40,6 +44,8 @@ def __init__( ) try: + import ray + ctx = ray.get_runtime_context() worker_id = ctx.get_actor_id() or ctx.get_worker_id() worker_id_short = worker_id[-6:] if worker_id else "driver" @@ -62,9 +68,11 @@ def __init__( ) def __call__( - self, batch: pd.DataFrame - ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]: + self, batch: "pd.DataFrame" + ) -> Union["pd.DataFrame", Iterable["pd.DataFrame"]]: # lazy import to avoid circular import + import pandas as pd + from graphgen.utils import CURRENT_LOGGER_VAR logger_token = CURRENT_LOGGER_VAR.set(self.logger) @@ -106,7 +114,7 @@ def get_trace_id(self, content: dict) -> str: return compute_dict_hash(content, prefix=f"{self.op_name}-") - def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]: """ Split the input batch into to_process & processed based on _meta data in KV_storage :param batch @@ -114,6 +122,8 @@ def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: to_process: DataFrame of documents to be chunked recovered: Result DataFrame of already chunked documents """ + import pandas as pd + meta_forward = self.get_meta_forward() meta_ids = set(meta_forward.keys()) mask = batch["_trace_id"].isin(meta_ids) diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index ba72f410..09354a20 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,10 +1,14 @@ +from __future__ import annotations + import os from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union -import pandas as pd import requests -from ray.data import Dataset + +if TYPE_CHECKING: + import pandas as pd + from ray.data import Dataset class BaseReader(ABC): @@ -51,6 +55,7 @@ def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: """ Validate data format. """ + if "type" not in batch.columns: raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py index 52604432..56bffedf 100644 --- a/graphgen/common/init_llm.py +++ b/graphgen/common/init_llm.py @@ -1,11 +1,12 @@ import os -from typing import Any, Dict, Optional - -import ray +from typing import TYPE_CHECKING, Any, Dict, Optional from graphgen.bases import BaseLLMWrapper from graphgen.models import Tokenizer +if TYPE_CHECKING: + import ray + class LLMServiceActor: """ @@ -73,7 +74,7 @@ class LLMServiceProxy(BaseLLMWrapper): A proxy class to interact with the LLMServiceActor for distributed LLM operations. """ - def __init__(self, actor_handle: ray.actor.ActorHandle): + def __init__(self, actor_handle: "ray.actor.ActorHandle"): super().__init__() self.actor_handle = actor_handle self._create_local_tokenizer() @@ -120,6 +121,8 @@ class LLMFactory: def create_llm( model_type: str, backend: str, config: Dict[str, Any] ) -> BaseLLMWrapper: + import ray + if not config: raise ValueError( f"No configuration provided for LLM {model_type} with backend {backend}." diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py index 3e32371f..1176d408 100644 --- a/graphgen/common/init_storage.py +++ b/graphgen/common/init_storage.py @@ -146,7 +146,7 @@ def ready(self) -> bool: class RemoteKVStorageProxy(BaseKVStorage): - def __init__(self, actor_handle: ray.actor.ActorHandle): + def __init__(self, actor_handle: "ray.actor.ActorHandle"): super().__init__() self.actor = actor_handle @@ -202,68 +202,87 @@ def get_all_node_degrees(self) -> Dict[str, int]: return ray.get(self.actor.get_all_node_degrees.remote()) def get_node_count(self) -> int: + return ray.get(self.actor.get_node_count.remote()) def get_edge_count(self) -> int: + return ray.get(self.actor.get_edge_count.remote()) def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + return ray.get(self.actor.get_connected_components.remote(undirected)) def has_node(self, node_id: str) -> bool: + return ray.get(self.actor.has_node.remote(node_id)) def has_edge(self, source_node_id: str, target_node_id: str): + return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id)) def node_degree(self, node_id: str) -> int: + return ray.get(self.actor.node_degree.remote(node_id)) def edge_degree(self, src_id: str, tgt_id: str) -> int: + return ray.get(self.actor.edge_degree.remote(src_id, tgt_id)) def get_node(self, node_id: str) -> Any: + return ray.get(self.actor.get_node.remote(node_id)) def update_node(self, node_id: str, node_data: dict[str, str]): + return ray.get(self.actor.update_node.remote(node_id, node_data)) def get_all_nodes(self) -> Any: + return ray.get(self.actor.get_all_nodes.remote()) def get_edge(self, source_node_id: str, target_node_id: str): + return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id)) def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): + return ray.get( self.actor.update_edge.remote(source_node_id, target_node_id, edge_data) ) def get_all_edges(self) -> Any: + return ray.get(self.actor.get_all_edges.remote()) def get_node_edges(self, source_node_id: str) -> Any: + return ray.get(self.actor.get_node_edges.remote(source_node_id)) def upsert_node(self, node_id: str, node_data: dict[str, str]): + return ray.get(self.actor.upsert_node.remote(node_id, node_data)) def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): + return ray.get( self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data) ) def delete_node(self, node_id: str): + return ray.get(self.actor.delete_node.remote(node_id)) def get_neighbors(self, node_id: str) -> List[str]: + return ray.get(self.actor.get_neighbors.remote(node_id)) def reload(self): + return ray.get(self.actor.reload.remote()) @@ -274,6 +293,7 @@ class StorageFactory: @staticmethod def create_storage(backend: str, working_dir: str, namespace: str): + if backend in ["json_kv", "rocksdb"]: actor_name = f"Actor_KV_{namespace}" actor_class = KVStorageActor diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index eed18857..2381d9b1 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,48 +1,121 @@ -from .evaluator import ( - AccuracyEvaluator, - LengthEvaluator, - MTLDEvaluator, - RewardEvaluator, - StructureEvaluator, - UniEvaluator, -) -from .filter import RangeFilter -from .generator import ( - AggregatedGenerator, - AtomicGenerator, - CoTGenerator, - FillInBlankGenerator, - MultiAnswerGenerator, - MultiChoiceGenerator, - MultiHopGenerator, - QuizGenerator, - TrueFalseGenerator, - VQAGenerator, -) -from .kg_builder import LightRAGKGBuilder, MMKGBuilder -from .llm import HTTPClient, OllamaClient, OpenAIClient -from .partitioner import ( - AnchorBFSPartitioner, - BFSPartitioner, - DFSPartitioner, - ECEPartitioner, - LeidenPartitioner, -) -from .reader import ( - CSVReader, - JSONReader, - ParquetReader, - PDFReader, - PickleReader, - RDFReader, - TXTReader, -) -from .rephraser import StyleControlledRephraser -from .searcher.db.ncbi_searcher import NCBISearch -from .searcher.db.rnacentral_searcher import RNACentralSearch -from .searcher.db.uniprot_searcher import UniProtSearch -from .searcher.kg.wiki_search import WikiSearch -from .searcher.web.bing_search import BingSearch -from .searcher.web.google_search import GoogleSearch -from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .tokenizer import Tokenizer +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .evaluator import ( + AccuracyEvaluator, + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + StructureEvaluator, + UniEvaluator, + ) + from .filter import RangeFilter + from .generator import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + FillInBlankGenerator, + MultiAnswerGenerator, + MultiChoiceGenerator, + MultiHopGenerator, + QuizGenerator, + TrueFalseGenerator, + VQAGenerator, + ) + from .kg_builder import LightRAGKGBuilder, MMKGBuilder + from .llm import HTTPClient, OllamaClient, OpenAIClient + from .partitioner import ( + AnchorBFSPartitioner, + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, + ) + from .reader import ( + CSVReader, + JSONReader, + ParquetReader, + PDFReader, + PickleReader, + RDFReader, + TXTReader, + ) + from .rephraser import StyleControlledRephraser + from .searcher.db.ncbi_searcher import NCBISearch + from .searcher.db.rnacentral_searcher import RNACentralSearch + from .searcher.db.uniprot_searcher import UniProtSearch + from .searcher.kg.wiki_search import WikiSearch + from .searcher.web.bing_search import BingSearch + from .searcher.web.google_search import GoogleSearch + from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter + from .tokenizer import Tokenizer + +_import_map = { + # Evaluator + "AccuracyEvaluator": ".evaluator", + "LengthEvaluator": ".evaluator", + "MTLDEvaluator": ".evaluator", + "RewardEvaluator": ".evaluator", + "StructureEvaluator": ".evaluator", + "UniEvaluator": ".evaluator", + # Filter + "RangeFilter": ".filter", + # Generator + "AggregatedGenerator": ".generator", + "AtomicGenerator": ".generator", + "CoTGenerator": ".generator", + "FillInBlankGenerator": ".generator", + "MultiAnswerGenerator": ".generator", + "MultiChoiceGenerator": ".generator", + "MultiHopGenerator": ".generator", + "QuizGenerator": ".generator", + "TrueFalseGenerator": ".generator", + "VQAGenerator": ".generator", + # KG Builder + "LightRAGKGBuilder": ".kg_builder", + "MMKGBuilder": ".kg_builder", + # LLM + "HTTPClient": ".llm", + "OllamaClient": ".llm", + "OpenAIClient": ".llm", + # Partitioner + "AnchorBFSPartitioner": ".partitioner", + "BFSPartitioner": ".partitioner", + "DFSPartitioner": ".partitioner", + "ECEPartitioner": ".partitioner", + "LeidenPartitioner": ".partitioner", + # Reader + "CSVReader": ".reader", + "JSONReader": ".reader", + "ParquetReader": ".reader", + "PDFReader": ".reader", + "PickleReader": ".reader", + "RDFReader": ".reader", + "TXTReader": ".reader", + # Searcher + "NCBISearch": ".searcher.db.ncbi_searcher", + "RNACentralSearch": ".searcher.db.rnacentral_searcher", + "UniProtSearch": ".searcher.db.uniprot_searcher", + "WikiSearch": ".searcher.kg.wiki_search", + "BingSearch": ".searcher.web.bing_search", + "GoogleSearch": ".searcher.web.google_search", + # Splitter + "ChineseRecursiveTextSplitter": ".splitter", + "RecursiveCharacterSplitter": ".splitter", + # Tokenizer + "Tokenizer": ".tokenizer", + # Rephraser + "StyleControlledRephraser": ".rephraser", +} + + +def __getattr__(name): + if name in _import_map: + import importlib + + module = importlib.import_module(_import_map[name], package=__name__) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = list(_import_map.keys()) diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index 380459ad..ee6c8b15 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -1,9 +1,6 @@ from collections import Counter from typing import Any, Dict, Optional -import numpy as np -from scipy import stats - from graphgen.bases import BaseGraphStorage, BaseKGEvaluator from graphgen.utils import logger @@ -75,6 +72,9 @@ def evaluate(self, kg: BaseGraphStorage) -> Dict[str, Any]: @staticmethod def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]: + import numpy as np + from scipy import stats + degrees = [deg for deg in degree_map.values() if deg > 0] if len(degrees) < 10: diff --git a/graphgen/models/filter/range_filter.py b/graphgen/models/filter/range_filter.py index 185c19cf..e0444a71 100644 --- a/graphgen/models/filter/range_filter.py +++ b/graphgen/models/filter/range_filter.py @@ -1,9 +1,10 @@ -from typing import Union - -import numpy as np +from typing import TYPE_CHECKING, Union from graphgen.bases import BaseValueFilter +if TYPE_CHECKING: + import numpy as np + class RangeFilter(BaseValueFilter): """ @@ -22,7 +23,7 @@ def __init__( self.left_inclusive = left_inclusive self.right_inclusive = right_inclusive - def filter(self, data: Union[int, float, np.number]) -> bool: + def filter(self, data: Union[int, float, "np.number"]) -> bool: value = float(data) if self.left_inclusive and self.right_inclusive: return self.min_val <= value <= self.max_val diff --git a/graphgen/models/llm/__init__.py b/graphgen/models/llm/__init__.py index c70395d5..d57bd11d 100644 --- a/graphgen/models/llm/__init__.py +++ b/graphgen/models/llm/__init__.py @@ -1,4 +1,27 @@ -from .api.http_client import HTTPClient -from .api.ollama_client import OllamaClient -from .api.openai_client import OpenAIClient -from .local.hf_wrapper import HuggingFaceWrapper +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .api.http_client import HTTPClient + from .api.ollama_client import OllamaClient + from .api.openai_client import OpenAIClient + from .local.hf_wrapper import HuggingFaceWrapper + + +_import_map = { + "HTTPClient": ".api.http_client", + "OllamaClient": ".api.ollama_client", + "OpenAIClient": ".api.openai_client", + "HuggingFaceWrapper": ".local.hf_wrapper", +} + + +def __getattr__(name): + if name in _import_map: + import importlib + + module = importlib.import_module(_import_map[name], package=__name__) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = list(_import_map.keys()) diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index b62b8544..39dc3f40 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -1,12 +1,12 @@ from collections import defaultdict -from typing import Any, Dict, List, Set, Tuple - -import igraph as ig -from leidenalg import ModularityVertexPartition, find_partition +from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple from graphgen.bases import BaseGraphStorage, BasePartitioner from graphgen.bases.datatypes import Community +if TYPE_CHECKING: + import igraph as ig + class LeidenPartitioner(BasePartitioner): """ @@ -62,6 +62,9 @@ def _run_leiden( use_lcc: bool = False, random_seed: int = 42, ) -> Dict[str, int]: + import igraph as ig + from leidenalg import ModularityVertexPartition, find_partition + # build igraph ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index 2f6ba4c7..98c16384 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,10 +1,11 @@ -from typing import List, Union - -import ray -from ray.data import Dataset +from typing import TYPE_CHECKING, List, Union from graphgen.bases.base_reader import BaseReader +if TYPE_CHECKING: + import ray + from ray.data import Dataset + class CSVReader(BaseReader): """ @@ -14,13 +15,14 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, input_path: Union[str, List[str]]) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> "Dataset": """ Read CSV files and return Ray Dataset. :param input_path: Path to CSV file or list of CSV files. :return: Ray Dataset containing validated and filtered data. """ + import ray ds = ray.data.read_csv(input_path, include_paths=True) ds = ds.map_batches(self._validate_batch, batch_format="pandas") diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index b8bb7f76..8fcb4d44 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,11 +1,12 @@ import json -from typing import List, Union - -import ray -import ray.data +from typing import TYPE_CHECKING, List, Union from graphgen.bases.base_reader import BaseReader +if TYPE_CHECKING: + import ray + import ray.data + class JSONReader(BaseReader): """ @@ -15,12 +16,14 @@ class JSONReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: + def read(self, input_path: Union[str, List[str]]) -> "ray.data.Dataset": """ Read JSON file and return Ray Dataset. :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. :return: Ray Dataset containing validated and filtered data. """ + import ray + if self.modalities and len(self.modalities) >= 2: ds: ray.data.Dataset = ray.data.from_items([]) for file in input_path if isinstance(input_path, list) else [input_path]: diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index cc283927..47ae5adc 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,10 +1,11 @@ -from typing import List, Union - -import ray -from ray.data import Dataset +from typing import TYPE_CHECKING, List, Union from graphgen.bases.base_reader import BaseReader +if TYPE_CHECKING: + import ray + from ray.data import Dataset + class ParquetReader(BaseReader): """ @@ -14,13 +15,15 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, input_path: Union[str, List[str]]) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> "Dataset": """ Read Parquet files using Ray Data. :param input_path: Path to Parquet file or list of Parquet files. :return: Ray Dataset containing validated documents. """ + import ray + if not ray.is_initialized(): ray.init() diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 55dab30b..2bfd52f1 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -3,15 +3,16 @@ import subprocess import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import ray -from ray.data import Dataset +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device +if TYPE_CHECKING: + import ray + from ray.data import Dataset + class PDFReader(BaseReader): """ @@ -69,7 +70,8 @@ def read( self, input_path: Union[str, List[str]], **override, - ) -> Dataset: + ) -> "Dataset": + import ray # Ensure input_path is a list if isinstance(input_path, str): diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 6e3d1949..5ff92b4d 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,13 +1,13 @@ import pickle -from typing import List, Union - -import pandas as pd -import ray -from ray.data import Dataset +from typing import TYPE_CHECKING, List, Union from graphgen.bases.base_reader import BaseReader from graphgen.utils import logger +if TYPE_CHECKING: + import pandas as pd + from ray.data import Dataset + class PickleReader(BaseReader): """ @@ -23,13 +23,16 @@ class PickleReader(BaseReader): def read( self, input_path: Union[str, List[str]], - ) -> Dataset: + ) -> "Dataset": """ Read Pickle files using Ray Data. :param input_path: Path to pickle file or list of pickle files. :return: Ray Dataset containing validated documents. """ + import pandas as pd + import ray + if not ray.is_initialized(): ray.init() @@ -37,7 +40,7 @@ def read( ds = ray.data.read_binary_files(input_path, include_paths=True) # Deserialize pickle files and flatten into individual records - def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame: + def deserialize_batch(batch: "pd.DataFrame") -> "pd.DataFrame": all_records = [] for _, row in batch.iterrows(): try: diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index 82e7d572..92f64377 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -1,15 +1,15 @@ from pathlib import Path -from typing import Any, Dict, List, Union - -import ray -import rdflib -from ray.data import Dataset -from rdflib import Literal -from rdflib.util import guess_format +from typing import TYPE_CHECKING, Any, Dict, List, Union from graphgen.bases.base_reader import BaseReader from graphgen.utils import logger +if TYPE_CHECKING: + import ray + import rdflib + from ray.data import Dataset + from rdflib import Literal + class RDFReader(BaseReader): """ @@ -30,13 +30,15 @@ def __init__(self, *, text_column: str = "content", **kwargs): def read( self, input_path: Union[str, List[str]], - ) -> Dataset: + ) -> "Dataset": """ Read RDF file(s) using Ray Data. :param input_path: Path to RDF file or list of RDF files. :return: Ray Dataset containing extracted documents. """ + import ray + if not ray.is_initialized(): ray.init() @@ -73,6 +75,10 @@ def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: :param file_path: Path to RDF file. :return: List of document dictionaries. """ + import rdflib + from rdflib import Literal + from rdflib.util import guess_format + if not file_path.is_file(): raise FileNotFoundError(f"RDF file not found: {file_path}") diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index 784dbe96..c5fc20c2 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,21 +1,24 @@ -from typing import List, Union - -import ray -from ray.data import Dataset +from typing import TYPE_CHECKING, List, Union from graphgen.bases.base_reader import BaseReader +if TYPE_CHECKING: + import ray + from ray.data import Dataset + class TXTReader(BaseReader): def read( self, input_path: Union[str, List[str]], - ) -> Dataset: + ) -> "Dataset": """ Read text files from the specified input path. :param input_path: Path to the input text file or list of text files. :return: Ray Dataset containing the read text data. """ + import ray + docs_ds = ray.data.read_binary_files( input_path, include_paths=True, diff --git a/graphgen/models/tokenizer/__init__.py b/graphgen/models/tokenizer/__init__.py index 6712f918..5de9cf04 100644 --- a/graphgen/models/tokenizer/__init__.py +++ b/graphgen/models/tokenizer/__init__.py @@ -4,29 +4,21 @@ from .tiktoken_tokenizer import TiktokenTokenizer -try: - from transformers import AutoTokenizer - - _HF_AVAILABLE = True -except ImportError: - _HF_AVAILABLE = False - def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer: import tiktoken if tokenizer_name in tiktoken.list_encoding_names(): return TiktokenTokenizer(model_name=tokenizer_name) - - # 2. HuggingFace - if _HF_AVAILABLE: + try: + # HuggingFace from .hf_tokenizer import HFTokenizer return HFTokenizer(model_name=tokenizer_name) - - raise ValueError( - f"Unknown tokenizer {tokenizer_name} and HuggingFace not available." - ) + except ImportError as e: + raise ValueError( + f"Unknown tokenizer {tokenizer_name} and HuggingFace not available." + ) from e class Tokenizer(BaseTokenizer): diff --git a/graphgen/models/tokenizer/hf_tokenizer.py b/graphgen/models/tokenizer/hf_tokenizer.py index c43ddd7d..bcb835f8 100644 --- a/graphgen/models/tokenizer/hf_tokenizer.py +++ b/graphgen/models/tokenizer/hf_tokenizer.py @@ -1,13 +1,13 @@ from typing import List -from transformers import AutoTokenizer - from graphgen.bases import BaseTokenizer class HFTokenizer(BaseTokenizer): def __init__(self, model_name: str = "cl100k_base"): super().__init__(model_name) + from transformers import AutoTokenizer + self.enc = AutoTokenizer.from_pretrained(self.model_name) def encode(self, text: str) -> List[int]: diff --git a/graphgen/models/tokenizer/tiktoken_tokenizer.py b/graphgen/models/tokenizer/tiktoken_tokenizer.py index 6145d070..3c641c3b 100644 --- a/graphgen/models/tokenizer/tiktoken_tokenizer.py +++ b/graphgen/models/tokenizer/tiktoken_tokenizer.py @@ -1,13 +1,13 @@ from typing import List -import tiktoken - from graphgen.bases import BaseTokenizer class TiktokenTokenizer(BaseTokenizer): def __init__(self, model_name: str = "cl100k_base"): super().__init__(model_name) + import tiktoken + self.enc = tiktoken.get_encoding(self.model_name) def encode(self, text: str) -> List[int]: diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index 68d67914..fc69f85d 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -1,36 +1,40 @@ import os from functools import lru_cache -from typing import Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union from graphgen.bases import BaseOperator -from graphgen.models import ( - ChineseRecursiveTextSplitter, - RecursiveCharacterSplitter, - Tokenizer, -) from graphgen.utils import detect_main_language -_MAPPING = { - "en": RecursiveCharacterSplitter, - "zh": ChineseRecursiveTextSplitter, -} +if TYPE_CHECKING: + from graphgen.models import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, + Tokenizer, + ) -SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] +if TYPE_CHECKING: + SplitterT = Union["RecursiveCharacterSplitter", "ChineseRecursiveTextSplitter"] +else: + SplitterT = Any @lru_cache(maxsize=None) def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: - cls = _MAPPING[language] kwargs = dict(frozen_kwargs) - return cls(**kwargs) + if language == "en": + from graphgen.models import RecursiveCharacterSplitter + + return RecursiveCharacterSplitter(**kwargs) + if language == "zh": + from graphgen.models import ChineseRecursiveTextSplitter + + return ChineseRecursiveTextSplitter(**kwargs) + raise ValueError( + f"Unsupported language: {language}. Supported languages are: en, zh" + ) def split_chunks(text: str, language: str = "en", **kwargs) -> list: - if language not in _MAPPING: - raise ValueError( - f"Unsupported language: {language}. " - f"Supported languages are: {list(_MAPPING.keys())}" - ) frozen_kwargs = frozenset( (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() ) @@ -45,10 +49,18 @@ def __init__( super().__init__( working_dir=working_dir, kv_backend=kv_backend, op_name="chunk" ) - tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") - self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) + self.tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self._tokenizer_instance: Optional["Tokenizer"] = None self.chunk_kwargs = chunk_kwargs + @property + def tokenizer_instance(self) -> "Tokenizer": + if self._tokenizer_instance is None: + from graphgen.models import Tokenizer + + self._tokenizer_instance = Tokenizer(model_name=self.tokenizer_model) + return self._tokenizer_instance + def process(self, batch: list) -> Tuple[list, dict]: """ Chunk the documents in the batch. diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index 5d362229..dfadf8da 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -3,14 +3,6 @@ from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer from graphgen.common.init_storage import init_storage -from graphgen.models import ( - AnchorBFSPartitioner, - BFSPartitioner, - DFSPartitioner, - ECEPartitioner, - LeidenPartitioner, - Tokenizer, -) from graphgen.utils import logger @@ -31,21 +23,34 @@ def __init__( namespace="graph", ) tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + + from graphgen.models import Tokenizer + self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) method = partition_kwargs["method"] self.method_params = partition_kwargs["method_params"] if method == "bfs": + from graphgen.models import BFSPartitioner + self.partitioner = BFSPartitioner() elif method == "dfs": + from graphgen.models import DFSPartitioner + self.partitioner = DFSPartitioner() elif method == "ece": # before ECE partitioning, we need to: # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random + from graphgen.models import ECEPartitioner + self.partitioner = ECEPartitioner() elif method == "leiden": + from graphgen.models import LeidenPartitioner + self.partitioner = LeidenPartitioner() elif method == "anchor_bfs": + from graphgen.models import AnchorBFSPartitioner + self.partitioner = AnchorBFSPartitioner( anchor_type=self.method_params.get("anchor_type"), anchor_ids=set(self.method_params.get("anchor_ids", [])) diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index b2f213b3..ec623a76 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -1,7 +1,5 @@ from pathlib import Path -from typing import Any, List, Optional, Union - -import ray +from typing import TYPE_CHECKING, Any, List, Optional, Union from graphgen.common.init_storage import init_storage from graphgen.models import ( @@ -17,6 +15,11 @@ from .parallel_file_scanner import ParallelFileScanner +if TYPE_CHECKING: + import ray + import ray.data + + _MAPPING = { "jsonl": JSONReader, "json": JSONReader, @@ -57,7 +60,7 @@ def read( recursive: bool = True, read_nums: Optional[int] = None, **reader_kwargs: Any, -) -> ray.data.Dataset: +) -> "ray.data.Dataset": """ Unified entry point to read files of multiple types using Ray Data. @@ -71,6 +74,8 @@ def read( :param reader_kwargs: Additional kwargs passed to readers :return: Ray Dataset containing all documents """ + import ray + input_path_cache = init_storage( backend=kv_backend, working_dir=working_dir, namespace="input_path" ) diff --git a/graphgen/operators/search/search_service.py b/graphgen/operators/search/search_service.py index 3f536ee7..7e25e225 100644 --- a/graphgen/operators/search/search_service.py +++ b/graphgen/operators/search/search_service.py @@ -1,12 +1,13 @@ from functools import partial -from typing import Optional - -import pandas as pd +from typing import TYPE_CHECKING, Optional from graphgen.bases import BaseOperator from graphgen.common.init_storage import init_storage from graphgen.utils import compute_content_hash, logger, run_concurrent +if TYPE_CHECKING: + import pandas as pd + class SearchService(BaseOperator): """ @@ -136,7 +137,9 @@ def _process_single_source( return final_results - def process(self, batch: pd.DataFrame) -> pd.DataFrame: + def process(self, batch: "pd.DataFrame") -> "pd.DataFrame": + import pandas as pd + docs = batch.to_dict(orient="records") self._init_searchers() diff --git a/graphgen/storage/graph/networkx_storage.py b/graphgen/storage/graph/networkx_storage.py index 2aa21400..24a9cfba 100644 --- a/graphgen/storage/graph/networkx_storage.py +++ b/graphgen/storage/graph/networkx_storage.py @@ -26,6 +26,7 @@ def get_edge_count(self) -> int: return self._graph.number_of_edges() def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + graph = self._graph if undirected and graph.is_directed(): @@ -36,24 +37,27 @@ def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: ] @staticmethod - def load_nx_graph(file_name) -> Optional[nx.Graph]: + def load_nx_graph(file_name) -> Optional["nx.Graph"]: + if os.path.exists(file_name): return nx.read_graphml(file_name) return None @staticmethod - def write_nx_graph(graph: nx.Graph, file_name): + def write_nx_graph(graph: "nx.Graph", file_name): + nx.write_graphml(graph, file_name) @staticmethod - def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + def stable_largest_connected_component(graph: "nx.Graph") -> "nx.Graph": """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Return the largest connected component of the graph, with nodes and edges sorted in a stable way. """ + from graspologic.utils import largest_connected_component graph = graph.copy() - graph = cast(nx.Graph, largest_connected_component(graph)) + graph = cast("nx.Graph", largest_connected_component(graph)) node_mapping = { node: html.unescape(node.upper().strip()) for node in graph.nodes() } # type: ignore @@ -61,11 +65,12 @@ def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: return NetworkXStorage._stabilize_graph(graph) @staticmethod - def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + def _stabilize_graph(graph: "nx.Graph") -> "nx.Graph": """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py Ensure an undirected graph with the same relationships will always be read the same way. 通过对节点和边进行排序来实现 """ + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() sorted_nodes = graph.nodes(data=True) @@ -97,6 +102,7 @@ def __post_init__(self): Initialize the NetworkX graph storage by loading an existing graph from a GraphML file, if it exists, or creating a new empty graph otherwise. """ + self._graphml_xml_file = os.path.join( self.working_dir, f"{self.namespace}.graphml" ) @@ -141,7 +147,7 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], No return list(self._graph.edges(source_node_id, data=True)) return None - def get_graph(self) -> nx.Graph: + def get_graph(self) -> "nx.Graph": return self._graph def upsert_node(self, node_id: str, node_data: dict[str, any]):