Skip to content

Commit f3850e2

Browse files
committed
✨ Add encoding parameter to write_file()
1 parent 59a3c34 commit f3850e2

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

bibtexparser/entrypoint.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,7 @@ def write_file(
139139
unparse_stack: Optional[Iterable[Middleware]] = None,
140140
prepend_middleware: Optional[Iterable[Middleware]] = None,
141141
bibtex_format: Optional[BibtexFormat] = None,
142+
encoding: str = "UTF-8",
142143
) -> None:
143144
"""Write a BibTeX database to a file.
144145
@@ -148,15 +149,16 @@ def write_file(
148149
If None, a default stack will be used.
149150
:param prepend_middleware: List of middleware to prepend to the default stack.
150151
Only applicable if `unparse_stack` is None.
151-
:param bibtex_format: Customized BibTeX format to use (optional)."""
152+
:param bibtex_format: Customized BibTeX format to use (optional).
153+
:param encoding: Encoding of the .bib file. Default encoding is ``"UTF-8"``."""
152154
bibtex_str = write_string(
153155
library=library,
154156
unparse_stack=unparse_stack,
155157
prepend_middleware=prepend_middleware,
156158
bibtex_format=bibtex_format,
157159
)
158160
if isinstance(file, str):
159-
with open(file, "w") as f:
161+
with open(file, "w", encoding=encoding) as f:
160162
f.write(bibtex_str)
161163
else:
162164
file.write(bibtex_str)

tests/test_entrypoint.py

Lines changed: 83 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,14 @@
1-
"""Testing the parse_file function."""
1+
"""Testing the parse_file and write_file functions."""
2+
3+
import os
4+
import tempfile
25

36
from bibtexparser import parse_file
7+
from bibtexparser import parse_string
8+
from bibtexparser import write_file
9+
from bibtexparser.model import Entry
10+
from bibtexparser.model import Field
11+
from bibtexparser.library import Library
412

513

614
def test_gbk():
@@ -9,3 +17,77 @@ def test_gbk():
917
assert library.entries[0]["title"] == "Test Title"
1018
assert library.entries[0]["year"] == "2013"
1119
assert library.entries[0]["journal"] == "测试期刊"
20+
21+
22+
def test_write_file_default_encoding():
    """write_file() called without ``encoding`` must produce UTF-8 output.

    A library with non-ASCII field values is written to a temporary ``.bib``
    file, which is then read back explicitly as UTF-8 and checked for the
    original strings.
    """
    author_field = Field(key="author", value="Müller")
    title_field = Field(key="title", value="Ångström measurements")
    article = Entry(
        entry_type="article",
        key="test2024",
        fields=[author_field, title_field],
    )
    library = Library([article])

    # Only the path is needed: the handle is closed on leaving the ``with``
    # block, and delete=False keeps the file around for write_file().
    with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as handle:
        bib_path = handle.name

    try:
        write_file(bib_path, library)
        with open(bib_path, encoding="UTF-8") as reader:
            written = reader.read()
        for expected in ("Müller", "Ångström"):
            assert expected in written
    finally:
        # Clean up manually because the temp file was created with delete=False.
        os.unlink(bib_path)
46+
47+
48+
def test_write_file_gbk_encoding():
    """write_file() must honour ``encoding="gbk"`` for Chinese characters.

    The library is written with GBK, read back with GBK, and the Chinese
    author and journal strings are checked in the decoded text.
    """
    fields = [
        Field(key="author", value="凯撒"),
        Field(key="title", value="Test Title"),
        Field(key="journal", value="测试期刊"),
    ]
    article = Entry(entry_type="article", key="test2024", fields=fields)
    library = Library([article])

    # Only the path is needed: the handle is closed on leaving the ``with``
    # block, and delete=False keeps the file around for write_file().
    with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as handle:
        bib_path = handle.name

    try:
        write_file(bib_path, library, encoding="gbk")
        with open(bib_path, encoding="gbk") as reader:
            written = reader.read()
        for expected in ("凯撒", "测试期刊"):
            assert expected in written
    finally:
        # Clean up manually because the temp file was created with delete=False.
        os.unlink(bib_path)
73+
74+
75+
def test_write_file_roundtrip_gbk():
    """Round-trip the GBK fixture through write_file() and parse_file().

    The fixture is parsed as GBK, written to a temporary file as GBK, and
    parsed again; the first entry's author and journal must be unchanged.
    """
    source_library = parse_file("tests/resources/gbk_test.bib", encoding="gbk")
    first_entry = source_library.entries[0]
    expected_author = first_entry["author"]
    expected_journal = first_entry["journal"]

    # Only the path is needed: the handle is closed on leaving the ``with``
    # block, and delete=False keeps the file around for write_file().
    with tempfile.NamedTemporaryFile(mode="w", suffix=".bib", delete=False) as handle:
        bib_path = handle.name

    try:
        write_file(bib_path, source_library, encoding="gbk")
        reparsed_entry = parse_file(bib_path, encoding="gbk").entries[0]
        assert reparsed_entry["author"] == expected_author
        assert reparsed_entry["journal"] == expected_journal
    finally:
        # Clean up manually because the temp file was created with delete=False.
        os.unlink(bib_path)

0 commit comments

Comments
 (0)