Skip to content

Commit b482106

Browse files
feat(extgen): make the generator idempotent and avoid touching the original source
1 parent 91c553f commit b482106

File tree

11 files changed

+832
-215
lines changed

11 files changed

+832
-215
lines changed

docs/extensions.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,18 @@ GEN_STUB_SCRIPT=php-src/build/gen_stub.php frankenphp extension-init my_extensio
587587
> [!NOTE]
588588
> Don't forget to set the `GEN_STUB_SCRIPT` environment variable to the path of the `gen_stub.php` file in the PHP sources you downloaded earlier. This is the same `gen_stub.php` script mentioned in the manual implementation section.
589589
590-
If everything went well, a new directory named `build` should have been created. This directory contains the generated files for your extension, including the `my_extension.go` file with the generated PHP function stubs.
590+
If everything went well, your project directory should contain the following files for your extension:
591+
592+
- **`my_extension.go`** - Your original source file (remains unchanged)
593+
- **`my_extension_generated.go`** - Generated file with CGO wrappers that call your functions
594+
- **`my_extension.stub.php`** - PHP stub file for IDE autocompletion
595+
- **`my_extension_arginfo.h`** - PHP argument information
596+
- **`my_extension.h`** - C header file
597+
- **`my_extension.c`** - C implementation file
598+
- **`README.md`** - Documentation
599+
600+
> [!IMPORTANT]
601+
> **Your source file (`my_extension.go`) is never modified.** The generator creates a separate `_generated.go` file containing CGO wrappers that call your original functions. This means you can safely version control your source file without worrying about generated code polluting it.
591602
592603
### Integrating the Generated Extension into FrankenPHP
593604

internal/extgen/gofile.go

Lines changed: 134 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import (
44
"bytes"
55
_ "embed"
66
"fmt"
7-
"os"
7+
"go/format"
88
"path/filepath"
9+
"strings"
910
"text/template"
1011

1112
"github.com/Masterminds/sprig/v3"
@@ -21,7 +22,7 @@ type GoFileGenerator struct {
2122
type goTemplateData struct {
2223
PackageName string
2324
BaseName string
24-
Imports []string
25+
SanitizedBaseName string
2526
Constants []phpConstant
2627
Variables []string
2728
InternalFunctions []string
@@ -30,16 +31,7 @@ type goTemplateData struct {
3031
}
3132

3233
func (gg *GoFileGenerator) generate() error {
33-
filename := filepath.Join(gg.generator.BuildDir, gg.generator.BaseName+".go")
34-
35-
if _, err := os.Stat(filename); err == nil {
36-
backupFilename := filename + ".bak"
37-
if err := os.Rename(filename, backupFilename); err != nil {
38-
return fmt.Errorf("backing up existing Go file: %w", err)
39-
}
40-
41-
gg.generator.SourceFile = backupFilename
42-
}
34+
filename := filepath.Join(gg.generator.BuildDir, gg.generator.BaseName+"_generated.go")
4335

4436
content, err := gg.buildContent()
4537
if err != nil {
@@ -51,38 +43,18 @@ func (gg *GoFileGenerator) generate() error {
5143

5244
func (gg *GoFileGenerator) buildContent() (string, error) {
5345
sourceAnalyzer := SourceAnalyzer{}
54-
imports, variables, internalFunctions, err := sourceAnalyzer.analyze(gg.generator.SourceFile)
46+
packageName, variables, internalFunctions, err := sourceAnalyzer.analyze(gg.generator.SourceFile)
5547
if err != nil {
5648
return "", fmt.Errorf("analyzing source file: %w", err)
5749
}
5850

59-
filteredImports := make([]string, 0, len(imports))
60-
for _, imp := range imports {
61-
if imp != `"C"` && imp != `"unsafe"` && imp != `"github.com/dunglas/frankenphp"` && imp != `"runtime/cgo"` {
62-
filteredImports = append(filteredImports, imp)
63-
}
64-
}
65-
6651
classes := make([]phpClass, len(gg.generator.Classes))
6752
copy(classes, gg.generator.Classes)
6853

69-
if len(classes) > 0 {
70-
hasCgo := false
71-
for _, imp := range imports {
72-
if imp == `"runtime/cgo"` {
73-
hasCgo = true
74-
break
75-
}
76-
}
77-
if !hasCgo {
78-
filteredImports = append(filteredImports, `"runtime/cgo"`)
79-
}
80-
}
81-
8254
templateContent, err := gg.getTemplateContent(goTemplateData{
83-
PackageName: SanitizePackageName(gg.generator.BaseName),
55+
PackageName: packageName,
8456
BaseName: gg.generator.BaseName,
85-
Imports: filteredImports,
57+
SanitizedBaseName: SanitizePackageName(gg.generator.BaseName),
8658
Constants: gg.generator.Constants,
8759
Variables: variables,
8860
InternalFunctions: internalFunctions,
@@ -94,7 +66,12 @@ func (gg *GoFileGenerator) buildContent() (string, error) {
9466
return "", fmt.Errorf("executing template: %w", err)
9567
}
9668

97-
return templateContent, nil
69+
fc, err := format.Source([]byte(templateContent))
70+
if err != nil {
71+
return "", fmt.Errorf("formatting source: %w", err)
72+
}
73+
74+
return string(fc), nil
9875
}
9976

10077
func (gg *GoFileGenerator) getTemplateContent(data goTemplateData) (string, error) {
@@ -106,6 +83,10 @@ func (gg *GoFileGenerator) getTemplateContent(data goTemplateData) (string, erro
10683
funcMap["isVoid"] = func(t phpType) bool {
10784
return t == phpVoid
10885
}
86+
funcMap["extractGoFunctionName"] = extractGoFunctionName
87+
funcMap["extractGoFunctionSignatureParams"] = extractGoFunctionSignatureParams
88+
funcMap["extractGoFunctionSignatureReturn"] = extractGoFunctionSignatureReturn
89+
funcMap["extractGoFunctionCallParams"] = extractGoFunctionCallParams
10990

11091
tmpl := template.Must(template.New("gofile").Funcs(funcMap).Parse(goFileContent))
11192

@@ -128,7 +109,7 @@ type GoParameter struct {
128109
Type string
129110
}
130111

131-
var phpToGoTypeMap= map[phpType]string{
112+
var phpToGoTypeMap = map[phpType]string{
132113
phpString: "string",
133114
phpInt: "int64",
134115
phpFloat: "float64",
@@ -146,3 +127,119 @@ func (gg *GoFileGenerator) phpTypeToGoType(phpT phpType) string {
146127

147128
return "any"
148129
}
130+
131+
// extractGoFunctionName extracts the Go function name from a Go function signature string.
132+
func extractGoFunctionName(goFunction string) string {
133+
idx := strings.Index(goFunction, "func ")
134+
if idx == -1 {
135+
return ""
136+
}
137+
138+
start := idx + len("func ")
139+
140+
end := start
141+
for end < len(goFunction) && goFunction[end] != '(' {
142+
end++
143+
}
144+
145+
if end >= len(goFunction) {
146+
return ""
147+
}
148+
149+
return strings.TrimSpace(goFunction[start:end])
150+
}
151+
152+
// extractGoFunctionSignatureParams extracts the parameters from a Go function signature.
153+
func extractGoFunctionSignatureParams(goFunction string) string {
154+
start := strings.IndexByte(goFunction, '(')
155+
if start == -1 {
156+
return ""
157+
}
158+
start++
159+
160+
depth := 1
161+
end := start
162+
for end < len(goFunction) && depth > 0 {
163+
switch goFunction[end] {
164+
case '(':
165+
depth++
166+
case ')':
167+
depth--
168+
}
169+
if depth > 0 {
170+
end++
171+
}
172+
}
173+
174+
if end >= len(goFunction) {
175+
return ""
176+
}
177+
178+
return strings.TrimSpace(goFunction[start:end])
179+
}
180+
181+
// extractGoFunctionSignatureReturn extracts the return type from a Go function signature.
182+
func extractGoFunctionSignatureReturn(goFunction string) string {
183+
start := strings.IndexByte(goFunction, '(')
184+
if start == -1 {
185+
return ""
186+
}
187+
188+
depth := 1
189+
pos := start + 1
190+
for pos < len(goFunction) && depth > 0 {
191+
switch goFunction[pos] {
192+
case '(':
193+
depth++
194+
case ')':
195+
depth--
196+
}
197+
pos++
198+
}
199+
200+
if pos >= len(goFunction) {
201+
return ""
202+
}
203+
204+
end := strings.IndexByte(goFunction[pos:], '{')
205+
if end == -1 {
206+
return ""
207+
}
208+
end += pos
209+
210+
returnType := strings.TrimSpace(goFunction[pos:end])
211+
return returnType
212+
}
213+
214+
// extractGoFunctionCallParams extracts just the parameter names for calling a function.
215+
func extractGoFunctionCallParams(goFunction string) string {
216+
params := extractGoFunctionSignatureParams(goFunction)
217+
if params == "" {
218+
return ""
219+
}
220+
221+
var names []string
222+
parts := strings.Split(params, ",")
223+
for _, part := range parts {
224+
part = strings.TrimSpace(part)
225+
if len(part) == 0 {
226+
continue
227+
}
228+
229+
words := strings.Fields(part)
230+
if len(words) > 0 {
231+
names = append(names, words[0])
232+
}
233+
}
234+
235+
var result strings.Builder
236+
for i, name := range names {
237+
if i > 0 {
238+
result.WriteString(", ")
239+
}
240+
241+
result.WriteString(name)
242+
}
243+
244+
return result.String()
245+
}

0 commit comments

Comments
 (0)