Skip to content

Commit c883542

Browse files
committed
Add transformers_ud
1 parent e1d1b34 commit c883542

File tree

4 files changed

+126
-2
lines changed

4 files changed

+126
-2
lines changed

pythainlp/parse/core.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
_tagger = None
33
_tagger_name = ""
44

5-
def dependency_parsing(text: str, engine: str="esupar")->str:
5+
def dependency_parsing(text: str, model: str=None, engine: str="esupar")->str:
66
"""
77
Dependency Parsing
88
99
:param str text: text to do dependency parsing
10+
:param str model: model for using with engine \
11+
(for esupar and transformers_ud)
1012
:param str engine: the name of the dependency parser
1113
:return: str (conllu)
1214
@@ -17,6 +19,41 @@ def dependency_parsing(text: str, engine: str="esupar")->str:
1719
* *spacy_thai* - Tokenizer, POS-tagger, and dependency-parser \
1820
for Thai language, working on Universal Dependencies. \
1921
`GitHub <https://github.com/KoichiYasuoka/spacy-thai>`_
22+
* *transformers_ud* - TransformersUD \
23+
`GitHub <https://github.com/KoichiYasuoka/>`_
24+
25+
**Options for model (esupar engine)**
26+
* *th* (default) - KoichiYasuoka/roberta-base-thai-spm-upos model \
27+
`Huggingface \
28+
<https://huggingface.co/KoichiYasuoka/roberta-base-thai-spm-upos>`_
29+
* *KoichiYasuoka/deberta-base-thai-upos* - DeBERTa(V2) model \
30+
pre-trained on Thai Wikipedia texts for POS-tagging and \
31+
dependency-parsing `Huggingface \
32+
<https://huggingface.co/KoichiYasuoka/deberta-base-thai-upos>`_
33+
* *KoichiYasuoka/roberta-base-thai-syllable-upos* - RoBERTa model \
34+
pre-trained on Thai Wikipedia texts for POS-tagging and \
35+
dependency-parsing. (syllable level) `Huggingface \
36+
<https://huggingface.co/KoichiYasuoka/roberta-base-thai-syllable-upos>`_
37+
* *KoichiYasuoka/roberta-base-thai-char-upos* - RoBERTa model \
38+
pre-trained on Thai Wikipedia texts for POS-tagging \
39+
and dependency-parsing. (char level) `Huggingface \
40+
<https://huggingface.co/KoichiYasuoka/roberta-base-thai-char-upos>`_
41+
42+
If you want to train model for esupar, you can read \
43+
`GitHub <https://github.com/KoichiYasuoka/esupar>`_
44+
45+
**Options for model (transformers_ud engine)**
46+
* *KoichiYasuoka/deberta-base-thai-ud-head* (default) - \
47+
DeBERTa(V2) model pretrained on Thai Wikipedia texts \
48+
for dependency-parsing (head-detection on Universal \
49+
Dependencies) as question-answering, derived from \
50+
deberta-base-thai. \
51+
trained by th_blackboard.conll. `Huggingface \
52+
<https://huggingface.co/KoichiYasuoka/deberta-base-thai-ud-head>`_
53+
* *KoichiYasuoka/roberta-base-thai-spm-ud-head* - \
54+
roberta model pretrained on Thai Wikipedia texts \
55+
for dependency-parsing. `Huggingface \
56+
<https://huggingface.co/KoichiYasuoka/roberta-base-thai-spm-ud-head>`_
2057
2158
:Example:
2259
::
@@ -40,7 +77,10 @@ def dependency_parsing(text: str, engine: str="esupar")->str:
4077
if _tagger_name != engine:
4178
if engine == "esupar":
4279
from pythainlp.parse.esupar_engine import Parse
43-
_tagger = Parse()
80+
_tagger = Parse(model=model)
81+
elif engine == "transformers_ud":
82+
from pythainlp.parse.transformers_ud import Parse
83+
_tagger = Parse(model=model)
4484
elif engine == "spacy_thai":
4585
from pythainlp.parse.spacy_thai_engine import Parse
4686
_tagger = Parse()

pythainlp/parse/esupar_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
class Parse:
1111
def __init__(self, model: str="th") -> None:
    """
    Load the esupar pipeline.

    :param str model: name of the esupar model to load; ``None`` is
        accepted and treated as the default ``"th"`` so that callers
        can forward an unset ``model`` argument unchanged.
    """
    # Identity check: None is a singleton, and ``==`` could be
    # intercepted by a custom __eq__ on the argument.
    if model is None:
        model = "th"
    self.nlp = esupar.load(model)
1315

1416
def __call__(self, text):

pythainlp/parse/transformers_ud.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
TransformersUD
4+
5+
Author: Prof. Koichi Yasuoka
6+
7+
This tagger is provided under the terms of the apache-2.0 License.
8+
9+
The source: https://huggingface.co/KoichiYasuoka/deberta-base-thai-ud-head
10+
11+
GitHub: https://github.com/KoichiYasuoka
12+
"""
13+
import os
14+
import numpy
15+
import torch
16+
import ufal.chu_liu_edmonds
17+
from transformers import (
18+
AutoTokenizer,
19+
AutoModelForQuestionAnswering,
20+
AutoModelForTokenClassification,
21+
AutoConfig,
22+
TokenClassificationPipeline
23+
)
24+
from transformers.utils import cached_file
25+
26+
27+
class Parse:
    """
    Dependency parser built on a question-answering head-detection model.

    Three models are loaded from the same model id / local directory:
    the question-answering model itself, plus two token-classification
    models stored under its ``deprel`` and ``tagger`` sub-folders.
    Calling the instance on a text returns the parse as a CoNLL-U string.
    """

    def __init__(self, model: str="KoichiYasuoka/deberta-base-thai-ud-head") -> None:
        """
        Load the tokenizer and the three models.

        :param str model: model id or path of a local model directory;
            ``None`` selects the default
            ``KoichiYasuoka/deberta-base-thai-ud-head``.
        """
        if model is None:
            model = "KoichiYasuoka/deberta-base-thai-ud-head"
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model)
        load = AutoModelForTokenClassification.from_pretrained
        if os.path.isdir(model):
            # Local directory: the sub-models live in plain sub-folders.
            deprel_model = load(os.path.join(model, "deprel"))
            tagger_model = load(os.path.join(model, "tagger"))
        else:
            # Not a local directory: resolve each sub-model's config and
            # weights through cached_file and load them explicitly.
            deprel_config = AutoConfig.from_pretrained(
                cached_file(model, "deprel/config.json")
            )
            deprel_model = load(
                cached_file(model, "deprel/pytorch_model.bin"),
                config=deprel_config,
            )
            tagger_config = AutoConfig.from_pretrained(
                cached_file(model, "tagger/config.json")
            )
            tagger_model = load(
                cached_file(model, "tagger/pytorch_model.bin"),
                config=tagger_config,
            )
        self.deprel = TokenClassificationPipeline(
            model=deprel_model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
        )
        self.tagger = TokenClassificationPipeline(
            model=tagger_model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, text: str)->str:
        """
        Parse ``text`` and return the result in CoNLL-U format.

        :param str text: text to parse
        :return: one ``# text = ...`` comment line, one tab-separated
            10-column line per token, and a terminating blank line
        :rtype: str
        """
        # (start offset, end offset, dependency relation) for every word
        # group produced by the deprel pipeline.
        words = [
            (t["start"], t["end"], t["entity_group"])
            for t in self.deprel(text)
        ]
        n = len(words)
        header = "# text = " + text.replace("\n", " ") + "\n"
        if n == 0:
            # Nothing to parse (e.g. empty text): skip the model call,
            # which cannot be fed an empty batch, and emit only the header.
            return header + "\n"

        # Tagger output keyed by start offset; each label is split on "|".
        tags = {
            t["start"]: t["entity"].split("|") for t in self.tagger(text)
        }
        forms = [text[s:e] for s, e, _ in words]
        # scores[dependent, head]; index 0 is the artificial root and
        # entries that are never filled stay NaN.
        scores = numpy.full((n + 1, n + 1), numpy.nan)
        word_ids = self.tokenizer(forms, add_special_tokens=False)["input_ids"]

        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        mask_id = self.tokenizer.mask_token_id

        # One input sequence per word i, kept as a list of segments:
        #   [CLS] word_i [SEP] | word_0 .. [MASK at position i] .. | [SEP]
        # Every sequence has the same total length, so they stack into
        # one rectangular tensor.
        batch = []
        for i, ids in enumerate(word_ids):
            query = [cls_id] + ids + [sep_id]
            batch.append(
                [query]
                + word_ids[0:i]
                + [[mask_id]]
                + word_ids[i + 1:]
                + [[sep_id]]
            )
        # ends[i][j] = number of token ids in segments 0..j of sequence i,
        # i.e. the exclusive end offset of segment j.
        ends = [
            [len(sum(seq[0:j + 1], [])) for j in range(len(seq))]
            for seq in batch
        ]
        with torch.no_grad():
            output = self.model(
                input_ids=torch.tensor([sum(seq, []) for seq in batch]),
                # Type 0 for the query segment, type 1 for the rest.
                token_type_ids=torch.tensor(
                    [[0] * e[0] + [1] * (e[-1] - e[0]) for e in ends]
                ),
            )
        start_logits = output.start_logits.tolist()
        end_logits = output.end_logits.tolist()

        for i in range(n):
            for j in range(n):
                # Score of word j as head of word i: start logit on the
                # first token of segment j+1 plus end logit on its last
                # token.  For j == i that segment is the [MASK], and the
                # score is stored in the root column instead.
                scores[i + 1, 0 if i == j else j + 1] = (
                    start_logits[i][ends[i][j]]
                    + end_logits[i][ends[i][j + 1] - 1]
                )

        heads = ufal.chu_liu_edmonds.chu_liu_edmonds(scores)[0]
        if sum(1 for head in heads if head == 0) != 1:
            # The tree does not have exactly one word attached to the
            # root.  Pick a single root -- the first word labelled "root"
            # by the deprel pipeline, otherwise the word with the best
            # root score -- forbid every other root attachment and decode
            # again.
            i = ([p for _, _, p in words] + ["root"]).index("root")
            j = i + 1 if i < n else numpy.nanargmax(scores[:, 0])
            scores[0:j, 0] = scores[j + 1:, 0] = numpy.nan
            heads = ufal.chu_liu_edmonds.chu_liu_edmonds(scores)[0]

        rows = []
        for i, (start, end, deprel) in enumerate(words, 1):
            # Keep the relation consistent with the decoded tree: the word
            # attached to the root is "root"; any other word that was
            # labelled "root" is downgraded to "dep".
            if heads[i] == 0:
                deprel = "root"
            elif deprel == "root":
                deprel = "dep"
            # "_" only when a gap separates this word from the next one.
            misc = "_" if i < n and end < words[i][0] else "SpaceAfter=No"
            rows.append(
                "\t".join(
                    [
                        str(i),
                        forms[i - 1],
                        "_",
                        # Drops the first two characters of the label
                        # (presumably a "B-"/"I-" prefix -- TODO confirm).
                        tags[start][0][2:],
                        "_",
                        "|".join(tags[start][1:]),
                        str(heads[i]),
                        deprel,
                        "_",
                        misc,
                    ]
                )
                + "\n"
            )
        return header + "".join(rows) + "\n"

tests/test_parse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
class TestParsePackage(unittest.TestCase):
88
def test_dependency_parsing(self):
99
self.assertIsNotNone(dependency_parsing("ผมเป็นคนดี", engine="esupar"))
10+
self.assertIsNotNone(dependency_parsing("ผมเป็นคนดี", engine="transformers_ud"))
1011
self.assertIsNotNone(dependency_parsing("ผมเป็นคนดี", engine="spacy_thai"))

0 commit comments

Comments
 (0)