22
33from typing import Any , Dict , List , Optional
44
5- from langchain import BasePromptTemplate , LLMChain
65from langchain .base_language import BaseLanguageModel
76from langchain .callbacks .manager import (
87 AsyncCallbackManagerForChainRun ,
98 CallbackManagerForChainRun ,
109)
10+ from langchain .chains import LLMChain
1111from langchain .chains .base import Chain
1212from langchain .output_parsers import OutputFixingParser , PydanticOutputParser
1313from langchain .schema import BaseOutputParser
14+ from langchain_core .prompts import BasePromptTemplate
1415from pydantic import Extra , Field
1516
1617from codedog .chains .pr_summary .prompts import CODE_SUMMARY_PROMPT , PR_SUMMARY_PROMPT
2021 PullRequestProcessor ,
2122)
2223
# Module-level processor instance, built once at import time and used by the
# PRSummaryChain methods below to pick out diff code files and to generate the
# prompt materials.
# NOTE(review): presumably the processor holds no per-PR state, so sharing one
# instance across chains is safe -- confirm against PullRequestProcessor.
processor = PullRequestProcessor.build()
25+
2326
2427class PRSummaryChain (Chain ):
2528 """Summarize a pull request.
@@ -32,17 +35,13 @@ class PRSummaryChain(Chain):
3235 - code_summaries(Dict[str, str]): changed code file summarizations, key is file path.
3336 """
3437
35- # TODO: input keys validation
36-
3738 code_summary_chain : LLMChain = Field (exclude = True )
3839 """Chain to use to summarize code change."""
3940 pr_summary_chain : LLMChain = Field (exclude = True )
4041 """Chain to use to summarize PR."""
4142
4243 parser : BaseOutputParser = Field (exclude = True )
4344 """Parse pr summarized result to PRSummary object."""
44- processor : PullRequestProcessor = Field (exclude = True , default_factory = PullRequestProcessor .build )
45- """PR data process."""
4645
4746 _input_keys : List [str ] = ["pull_request" ]
4847 _output_keys : List [str ] = ["pr_summary" , "code_summaries" ]
@@ -78,15 +77,21 @@ def review(self, inputs, _run_manager) -> Dict[str, Any]:
7877
7978 code_summary_inputs = self ._process_code_summary_inputs (pr )
8079 code_summary_outputs = (
81- self .code_summary_chain .apply (code_summary_inputs , callbacks = _run_manager .get_child (tag = "CodeSummary" ))
80+ self .code_summary_chain .apply (
81+ code_summary_inputs , callbacks = _run_manager .get_child (tag = "CodeSummary" )
82+ )
8283 if code_summary_inputs
8384 else []
8485 )
8586
86- code_summaries = self .processor .build_change_summaries (code_summary_inputs , code_summary_outputs )
87+ code_summaries = processor .build_change_summaries (
88+ code_summary_inputs , code_summary_outputs
89+ )
8790
8891 pr_summary_input = self ._process_pr_summary_input (pr , code_summaries )
89- pr_summary_output = self .pr_summary_chain (pr_summary_input , callbacks = _run_manager .get_child (tag = "PRSummary" ))
92+ pr_summary_output = self .pr_summary_chain (
93+ pr_summary_input , callbacks = _run_manager .get_child (tag = "PRSummary" )
94+ )
9095
9196 return self ._process_result (pr_summary_output , code_summaries )
9297
@@ -95,26 +100,38 @@ async def areview(self, inputs, _run_manager) -> Dict[str, Any]:
95100
96101 code_summary_inputs = self ._process_code_summary_inputs (pr )
97102 code_summary_outputs = (
98- await self .code_summary_chain .aapply (code_summary_inputs , callbacks = _run_manager .get_child ())
103+ await self .code_summary_chain .aapply (
104+ code_summary_inputs , callbacks = _run_manager .get_child ()
105+ )
99106 if code_summary_inputs
100107 else []
101108 )
102109
103- code_summaries = self .processor .build_change_summaries (code_summary_inputs , code_summary_outputs )
110+ code_summaries = processor .build_change_summaries (
111+ code_summary_inputs , code_summary_outputs
112+ )
104113
105114 pr_summary_input = self ._process_pr_summary_input (pr , code_summaries )
106- pr_summary_output = await self .pr_summary_chain .acall (pr_summary_input , callbacks = _run_manager .get_child ())
115+ pr_summary_output = await self .pr_summary_chain .ainvoke (
116+ pr_summary_input , callbacks = _run_manager .get_child ()
117+ )
107118
108119 return await self ._aprocess_result (pr_summary_output , code_summaries )
109120
110- def _call (self , inputs : Dict [str , Any ], run_manager : Optional [CallbackManagerForChainRun ] = None ) -> Dict [str , Any ]:
121+ def _call (
122+ self ,
123+ inputs : Dict [str , Any ],
124+ run_manager : Optional [CallbackManagerForChainRun ] = None ,
125+ ) -> Dict [str , Any ]:
111126 _run_manager = run_manager or CallbackManagerForChainRun .get_noop_manager ()
112127 _run_manager .on_text (inputs ["pull_request" ].json () + "\n " )
113128
114129 return self .review (inputs , _run_manager )
115130
116131 async def _acall (
117- self , inputs : Dict [str , Any ], run_manager : Optional [AsyncCallbackManagerForChainRun ] = None
132+ self ,
133+ inputs : Dict [str , Any ],
134+ run_manager : Optional [AsyncCallbackManagerForChainRun ] = None ,
118135 ) -> Dict [str , Any ]:
119136 _run_manager = run_manager or CallbackManagerForChainRun .get_noop_manager ()
120137 await _run_manager .on_text (inputs ["pull_request" ].json () + "\n " )
@@ -123,28 +140,36 @@ async def _acall(
123140
124141 def _process_code_summary_inputs (self , pr : PullRequest ) -> List [Dict [str , str ]]:
125142 input_data = []
126- code_files = self . processor .get_diff_code_files (pr )
143+ code_files = processor .get_diff_code_files (pr )
127144 for code_file in code_files :
128145 input_item = {
129- "content" : code_file .diff_content .content [:2000 ], # TODO: handle long diff
146+ "content" : code_file .diff_content .content [
147+ :2000
148+ ], # TODO: handle long diff
130149 "name" : code_file .full_name ,
131150 "language" : SUFFIX_LANGUAGE_MAPPING .get (code_file .suffix , "" ),
132151 }
133152 input_data .append (input_item )
134153
135154 return input_data
136155
137- def _process_pr_summary_input (self , pr : PullRequest , code_summaries : List [ChangeSummary ]) -> Dict [str , str ]:
138- change_files_material : str = self .processor .gen_material_change_files (pr .change_files )
139- code_summaries_material = self .processor .gen_material_code_summaries (code_summaries )
140- pr_metadata_material = self .processor .gen_material_pr_metadata (pr )
156+ def _process_pr_summary_input (
157+ self , pr : PullRequest , code_summaries : List [ChangeSummary ]
158+ ) -> Dict [str , str ]:
159+ change_files_material : str = processor .gen_material_change_files (
160+ pr .change_files
161+ )
162+ code_summaries_material = processor .gen_material_code_summaries (code_summaries )
163+ pr_metadata_material = processor .gen_material_pr_metadata (pr )
141164 return {
142165 "change_files" : change_files_material ,
143166 "code_summaries" : code_summaries_material ,
144167 "metadata" : pr_metadata_material ,
145168 }
146169
147- def _process_result (self , pr_summary_output : Dict [str , Any ], code_summaries : List [ChangeSummary ]) -> Dict [str , Any ]:
170+ def _process_result (
171+ self , pr_summary_output : Dict [str , Any ], code_summaries : List [ChangeSummary ]
172+ ) -> Dict [str , Any ]:
148173 return {
149174 "pr_summary" : pr_summary_output ["text" ],
150175 "code_summaries" : code_summaries ,
@@ -167,7 +192,16 @@ def from_llm(
167192 pr_summary_prompt : BasePromptTemplate = PR_SUMMARY_PROMPT ,
168193 ** kwargs ,
169194 ) -> PRSummaryChain :
170- parser = OutputFixingParser .from_llm (llm = pr_summary_llm , parser = PydanticOutputParser (pydantic_object = PRSummary ))
195+ parser = OutputFixingParser .from_llm (
196+ llm = pr_summary_llm , parser = PydanticOutputParser (pydantic_object = PRSummary )
197+ )
171198 code_summary_chain = LLMChain (llm = code_summary_llm , prompt = code_summary_prompt )
172- pr_summary_chain = LLMChain (llm = pr_summary_llm , prompt = pr_summary_prompt , output_parser = parser )
173- return cls (code_summary_chain = code_summary_chain , pr_summary_chain = pr_summary_chain , parser = parser , ** kwargs )
199+ pr_summary_chain = LLMChain (
200+ llm = pr_summary_llm , prompt = pr_summary_prompt , output_parser = parser
201+ )
202+ return cls (
203+ code_summary_chain = code_summary_chain ,
204+ pr_summary_chain = pr_summary_chain ,
205+ parser = parser ,
206+ ** kwargs ,
207+ )
0 commit comments