forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlexical_analysis.py
More file actions
265 lines (238 loc) Β· 9.89 KB
/
Copy pathlexical_analysis.py
File metadata and controls
265 lines (238 loc) Β· 9.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from ..data import Pad, Stack, Tuple
from ..datasets import load_dataset
from .models import BiGruCrf
from .task import Task
from .utils import Customization
usage = r"""
from paddlenlp import Taskflow
lac = Taskflow("lexical_analysis")
lac("LACζ―δΈͺδΌη§ηεθ―ε·₯ε
·")
'''
[{'text': 'LACζ―δΈͺδΌη§ηεθ―ε·₯ε
·', 'segs': ['LAC', 'ζ―', 'δΈͺ', 'δΌη§', 'η', 'εθ―', 'ε·₯ε
·'], 'tags': ['nz', 'v', 'q', 'a', 'u', 'n', 'n']}]
'''
lac(["LACζ―δΈͺδΌη§ηεθ―ε·₯ε
·", "δΈδΊζ―δΈδΈͺηΎδΈ½ηεεΈ"])
'''
[{'text': 'LACζ―δΈͺδΌη§ηεθ―ε·₯ε
·', 'segs': ['LAC', 'ζ―', 'δΈͺ', 'δΌη§', 'η', 'εθ―', 'ε·₯ε
·'], 'tags': ['nz', 'v', 'q', 'a', 'u', 'n', 'n']},
{'text': 'δΈδΊζ―δΈδΈͺηΎδΈ½ηεεΈ', 'segs': ['δΈδΊ', 'ζ―', 'δΈδΈͺ', 'ηΎδΈ½', 'η', 'εεΈ'], 'tags': ['LOC', 'v', 'm', 'a', 'u', 'n']}
]
'''
"""
def load_vocab(dict_path):
    """
    Read a tab-separated dictionary file into a ``{key: value}`` mapping.

    Each line holds either one field or two tab-separated fields:

    * one field: the field is the key and its value is the zero-based line
      number, converted to ``str``;
    * two fields: the column order is fixed by the first two-field line seen.
      If that line's first field is all digits, every two-field line is read
      as ``value<TAB>key``; otherwise as ``key<TAB>value``.

    All values are kept as strings.

    Args:
        dict_path (str): Path of the dictionary file, read as UTF-8.

    Returns:
        dict: Mapping from key (str) to value (str).

    Raises:
        ValueError: If a line contains more than two tab-separated fields.
    """
    mapping = {}
    id_first = None  # decided lazily by the first two-field line
    with open(dict_path, "r", encoding="utf8") as handle:
        for line_no, raw_line in enumerate(handle):
            fields = raw_line.strip("\n").split("\t")
            n_fields = len(fields)
            if n_fields == 1:
                mapping[fields[0]] = str(line_no)
                continue
            if n_fields != 2:
                raise ValueError("Error line: %s in file: %s" % (raw_line, dict_path))
            if id_first is None:
                id_first = fields[0].isdigit()
            if id_first:
                mapping[fields[1]] = fields[0]
            else:
                mapping[fields[0]] = fields[1]
    return mapping
class LacTask(Task):
    """
    Lexical analysis task that segments Chinese sentences into words and tags each word.

    Args:
        task(string): The name of task.
        model(string): The model name in the task.
        user_dict(string): The user-defined dictionary, default to None.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    resource_files_names = {
        "model_state": "model_state.pdparams",
        "tags": "tag.dic",
        "q2b": "q2b.dic",
        "word": "word.dic",
    }
    resource_files_urls = {
        "lac": {
            "model_state": [
                "https://bj.bcebos.com/paddlenlp/taskflow/lexical_analysis/lac/model_state.pdparams",
                "3d4008c6c9d29424465829c9acf909bd",
            ],
            "tags": [
                "https://bj.bcebos.com/paddlenlp/taskflow/lexical_analysis/lac/tag.dic",
                "b11b616926b9f7f0a40a8087f84a8a99",
            ],
            "q2b": [
                "https://bj.bcebos.com/paddlenlp/taskflow/lexical_analysis/lac/q2b.dic",
                "4ef2cd16f8002fe7cd7dd31cdff47e0d",
            ],
            "word": [
                "https://bj.bcebos.com/paddlenlp/taskflow/lexical_analysis/lac/word.dic",
                "f1dfc68139bb6dd58c9c4313c341e436",
            ],
        }
    }

    def __init__(self, task, model, user_dict=None, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
        self._usage = usage
        self._user_dict = user_dict
        self._check_task_files()
        self._construct_vocabs()
        self._get_inference_model()
        # Maximum length passed to `_auto_splitter` in `_preprocess`; longer texts
        # are split there and re-joined by `_auto_joiner` in `_postprocess`.
        self._max_seq_len = 512
        if self._user_dict:
            self._custom = Customization()
            self._custom.load_customization(self._user_dict)
        else:
            self._custom = None

    def _construct_input_spec(self):
        """
        Construct the input spec for converting the dygraph model to a static model.
        """
        self._input_spec = [
            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_ids"),
            paddle.static.InputSpec(shape=[None], dtype="int64", name="length"),
        ]

    def _construct_vocabs(self):
        """
        Load the word, tag and q2b dictionaries from the task directory.

        `load_vocab` keeps ids as strings, so the reverse (id -> token) maps are
        keyed by string ids as well.
        """
        word_dict_path = os.path.join(self._task_path, self.resource_files_names["word"])
        tag_dict_path = os.path.join(self._task_path, self.resource_files_names["tags"])
        q2b_dict_path = os.path.join(self._task_path, self.resource_files_names["q2b"])
        self._word_vocab = load_vocab(word_dict_path)
        self._tag_vocab = load_vocab(tag_dict_path)
        self._q2b_vocab = load_vocab(q2b_dict_path)
        self._id2word_dict = dict(zip(self._word_vocab.values(), self._word_vocab.keys()))
        self._id2tag_dict = dict(zip(self._tag_vocab.values(), self._tag_vocab.keys()))

    def _construct_model(self, model):
        """
        Construct the BiGRU-CRF model and load its trained parameters.
        """
        model_instance = BiGruCrf(
            self.kwargs["emb_dim"], self.kwargs["hidden_size"], len(self._word_vocab), len(self._tag_vocab)
        )
        # Load the trained parameters used for prediction.
        state_dict = paddle.load(os.path.join(self._task_path, self.resource_files_names["model_state"]))
        model_instance.set_dict(state_dict)
        self._model = model_instance
        self._model.eval()

    def _construct_tokenizer(self, model):
        """
        No tokenizer is needed: `_preprocess` maps text to ids character by character.
        """
        return None

    def _preprocess(self, inputs, padding=True, add_special_tokens=True):
        """
        Transform the raw text into model inputs.

        Non-string and blank inputs are dropped. The remaining texts are split
        into pieces of at most `self._max_seq_len`. Each character is mapped
        through the q2b table (falling back to itself) and then to its word id,
        using the "OOV" id for unknown characters.

        Args:
            inputs: Raw text or list of raw texts.
            padding, add_special_tokens: Unused; kept for interface compatibility.

        Returns:
            dict: {"text": the split texts, "data_loader": a DataLoader yielding
            (padded token ids, sequence lengths) batches}.
        """
        inputs = self._check_input_text(inputs)
        # Get the config from the kwargs
        batch_size = self.kwargs.get("batch_size", 1)
        num_workers = self.kwargs.get("num_workers", 0)
        self._split_sentence = self.kwargs.get("split_sentence", False)
        oov_token_id = self._word_vocab.get("OOV")

        filter_inputs = [text for text in inputs if isinstance(text, str) and text.strip()]

        short_input_texts, self.input_mapping = self._auto_splitter(
            filter_inputs, self._max_seq_len, split_sentence=self._split_sentence
        )

        q2b_vocab = self._q2b_vocab
        word_vocab = self._word_vocab

        def read(texts):
            for text in texts:
                ids = [word_vocab.get(q2b_vocab.get(char, char), oov_token_id) for char in text]
                yield ids, len(ids)

        infer_ds = load_dataset(read, inputs=short_input_texts, lazy=False)
        batchify_fn = Tuple(
            Pad(axis=0, pad_val=0, dtype="int64"),  # input_ids
            Stack(dtype="int64"),  # seq_len
        )
        infer_data_loader = paddle.io.DataLoader(
            infer_ds,
            collate_fn=batchify_fn,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=False,
            return_list=True,
        )
        return {"text": short_input_texts, "data_loader": infer_data_loader}

    def _run_model(self, inputs):
        """
        Run the inference predictor over every batch produced by `_preprocess`.

        Adds "result" (per-sentence tag-id lists as returned by the predictor;
        trimmed to "lens" in `_postprocess`) and "lens" (per-sentence lengths)
        to `inputs` and returns it.
        """
        results = []
        lens = []
        for input_ids, seq_len in inputs["data_loader"]:
            self.input_handles[0].copy_from_cpu(input_ids.numpy())
            self.input_handles[1].copy_from_cpu(seq_len.numpy())
            self.predictor.run()
            tags_ids = self.output_handle[0].copy_to_cpu()
            results.extend(tags_ids.tolist())
            lens.extend(seq_len.tolist())
        inputs["result"] = results
        inputs["lens"] = lens
        return inputs

    @staticmethod
    def _merge_by_tags(sent, tags):
        """
        Group the characters of `sent` into words according to per-character `tags`.

        A new word starts at the first character, at any tag ending in "-B", and
        at an "O" tag that follows a non-"O" tag; any other character extends the
        current word. Each word's tag is the part of its first character's tag
        before the first "-".

        Returns:
            tuple: (words, word_tags), two lists of equal length.
        """
        words = []
        word_tags = []
        partial_word = ""
        for ind, tag in enumerate(tags):
            if partial_word == "":
                partial_word = sent[ind]
                word_tags.append(tag.split("-")[0])
                continue
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                words.append(partial_word)
                word_tags.append(tag.split("-")[0])
                partial_word = sent[ind]
                continue
            partial_word += sent[ind]
        # Flush the word that was still being built when the tags ran out.
        if len(words) < len(word_tags):
            words.append(partial_word)
        return words, word_tags

    def _postprocess(self, inputs):
        """
        Convert the predicted tag ids into segmented words and word tags.

        Returns:
            list[dict]: One {"text", "segs", "tags"} dict per original input, after
            the split pieces are re-joined by `_auto_joiner`.
        """
        lengths = inputs["lens"]
        preds = inputs["result"]
        sents = inputs["text"]
        final_results = []
        for sent, pred, length in zip(sents, preds, lengths):
            tags = [self._id2tag_dict[str(index)] for index in pred[:length]]
            if self._custom:
                # Return value is unused, so this presumably rewrites `tags` in place
                # for user-dictionary matches — confirm in Customization.
                self._custom.parse_customization(sent, tags)
            segs, seg_tags = self._merge_by_tags(sent, tags)
            final_results.append({"text": sent, "segs": segs, "tags": seg_tags})
        return self._auto_joiner(final_results, self.input_mapping, is_dict=True)