Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ authors:
abstract: PyThaiNLP is a Thai natural language processing library for Python. It provides standard linguistic analysis for the Thai language, including tokenization and part-of-speech tagging. Additionally, it offers standard Thai locale utility functions, such as Thai Buddhist Era date formatting and the conversion of numbers into Thai text.
repository-code: "https://github.com/PyThaiNLP/pythainlp"
type: software
doi: 10.5281/zenodo.3519354
version: 5.2.0
license-url: "https://spdx.org/licenses/Apache-2.0"
keywords:
Expand Down
48 changes: 28 additions & 20 deletions pythainlp/translate/en_th.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from __future__ import annotations

import os
import warnings

try:
from fairseq.models.transformer import TransformerModel
Expand Down Expand Up @@ -126,26 +127,33 @@ def __init__(self, use_gpu: bool = False):
self._model_name = _TH_EN_MODEL_NAME

_download_install(self._model_name)
self._model = TransformerModel.from_pretrained(
model_name_or_path=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"models",
),
checkpoint_file="checkpoint.pt",
data_name_or_path=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"vocab",
),
bpe="sentencepiece",
sentencepiece_model=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"bpe",
"spm.th.model",
),
)
# Suppress model type mismatch warning from transformers
# The pre-trained model has camembert config but works fine
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="(?i).*using a model of type .* to instantiate a model of type.*",
)
self._model = TransformerModel.from_pretrained(
model_name_or_path=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"models",
),
checkpoint_file="checkpoint.pt",
data_name_or_path=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"vocab",
),
bpe="sentencepiece",
sentencepiece_model=_get_translate_path(
self._model_name,
_TH_EN_FILE_NAME,
"bpe",
"spm.th.model",
),
)
if use_gpu:
self._model.cuda()

Expand Down
Loading