77
88class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
99 def train (
10- self ,
11- graph_name : str ,
12- model_name : str ,
13- feature_properties : List [str ],
14- target_property : str ,
15- relationship_types : List [str ],
16- target_node_label : str = None ,
17- node_labels : List [str ] = None ,
10+ self ,
11+ graph_name : str ,
12+ model_name : str ,
13+ feature_properties : List [str ],
14+ target_property : str ,
15+ relationship_types : List [str ],
16+ target_node_label : str = None ,
17+ node_labels : List [str ] = None ,
1818 ) -> "Series[Any]" : # noqa: F821
1919 mlConfigMap = {
2020 "featureProperties" : feature_properties ,
@@ -40,26 +40,18 @@ def train(
4040 )
4141
4242 def predict (
43- self ,
44- graph_name : str ,
45- model_name : str ,
46- feature_properties : List [str ],
47- relationship_types : List [str ],
48- mutateProperty : str ,
49- target_node_label : str = None ,
50- node_labels : List [str ] = None ,
43+ self ,
44+ graph_name : str ,
45+ model_name : str ,
46+ mutateProperty : str ,
47+ predictedProbabilityProperty : str = None ,
5148 ) -> "Series[Any]" : # noqa: F821
5249 mlConfigMap = {
53- "featureProperties" : feature_properties ,
5450 "job_type" : "predict" ,
55- "nodeProperties" : feature_properties ,
56- "relationshipTypes" : relationship_types ,
5751 "mutateProperty" : mutateProperty
5852 }
59- if target_node_label :
60- mlConfigMap ["targetNodeLabel" ] = target_node_label
61- if node_labels :
62- mlConfigMap ["nodeLabels" ] = node_labels
53+ if predictedProbabilityProperty :
54+ mlConfigMap ["predictedProbabilityProperty" ] = predictedProbabilityProperty
6355
6456 mlTrainingConfig = json .dumps (mlConfigMap )
6557 self ._query_runner .run_query (
0 commit comments