# Generates a language-specific recognition training config
# (rec_<lang>_lite_train.yml) from rec_multi_language_lite_train.yml.

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging
logging.basicConfig(level=logging.INFO)
# Language codes accepted by -l/--language, mapped to a human-readable
# language name. The whole mapping is embedded in the --language help text,
# so entry order is kept exactly as is.
support_list = {
    "it": "italian",
    "xi": "spanish",
    "pu": "portuguese",
    "ru": "russian",
    "ar": "arabic",
    "ta": "tamil",
    "ug": "uyghur",
    "fa": "persian",
    "ur": "urdu",
    "rs": "serbian latin",
    "oc": "occitan",
    "rsc": "serbian cyrillic",
    "bg": "bulgarian",
    "uk": "ukranian",
    "be": "belarusian",
    "te": "telugu",
    "ka": "kannada",
    "chinese_cht": "chinese tradition",
    "hi": "hindi",
    "mr": "marathi",
    "ne": "nepali",
}
# Language codes grouped by script. When the language is resolved, any code
# found in one of these lists is replaced by the group's own name ("latin",
# "arabic", "cyrillic" or "devanagari"), which is also the last entry of the
# respective list.
latin_lang = [
    "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr",
    "ga", "hr", "hu", "id", "is", "it", "ku", "la", "lt", "lv",
    "mi", "ms", "mt", "nl", "no", "oc", "pi", "pl", "pt", "ro",
    "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", "uz", "vi",
    "latin",
]

arabic_lang = ["ar", "fa", "ug", "ur"]

cyrillic_lang = [
    "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady",
    "kbd", "ava", "dar", "inh", "che", "lbe", "lez", "tab",
    "cyrillic",
]

devanagari_lang = [
    "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah",
    "sck", "new", "gom", "sa", "bgc",
    "devanagari",
]

# Every code that belongs to one of the script groups, in group order.
multi_lang = [*latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang]
# The base recognition config must be present in the current working
# directory; every generated config starts from it.
assert os.path.isfile("./rec_multi_language_lite_train.yml"), (
    "Loss basic configuration file rec_multi_language_lite_train.yml."
    "You can download it from "
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
)

# Parse the base config once at import time. The file is opened in a ``with``
# block so the handle is closed as soon as parsing finishes instead of being
# left to the garbage collector.
# NOTE(review): yaml.Loader is PyYAML's unrestricted loader; that is fine for a
# config file the user placed here, but it must not be fed untrusted YAML.
with open("./rec_multi_language_lite_train.yml", "rb") as _base_config_file:
    global_config = yaml.load(_base_config_file, Loader=yaml.Loader)

# Directory three levels above the current working directory. Dataset and
# dictionary paths are joined onto it before their existence is checked.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
class ArgsParser(ArgumentParser):
    """Command-line parser for the multi-language config generator.

    On top of regular ``argparse`` parsing, :meth:`parse_args` converts the
    ``-o/--opt`` ``key=value`` items into a dict and resolves
    ``-l/--language`` to a single name, writing the language-specific
    defaults into the module-level ``global_config`` as it does so.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs="+",
            help="set language type, support {}".format(support_list),
        )
        self.add_argument(
            "--train",
            type=str,
            help="you can use this command to change the train dataset default path",
        )
        self.add_argument(
            "--val",
            type=str,
            help="you can use this command to change the eval dataset default path",
        )
        self.add_argument(
            "--dict",
            type=str,
            help="you can use this command to change the dictionary default path",
        )
        self.add_argument(
            "--data_dir",
            type=str,
            help="you can use this command to change the dataset default root path",
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process ``opt`` and ``language``.

        Args:
            argv (list[str] | None): Arguments to parse; ``None`` lets
                argparse fall back to ``sys.argv``.

        Returns:
            argparse.Namespace: ``opt`` is a dict of parsed overrides and
            ``language`` is the resolved language name (a string).
        """
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    def _parse_opt(self, opts):
        """Turn ``key=value`` strings into a ``{key: parsed_value}`` dict.

        Args:
            opts (list[str] | None): Raw ``-o/--opt`` items.

        Returns:
            dict: Keys exactly as given; each value is the YAML parse of the
            text after the first ``=`` (an empty value parses to ``None``).

        Raises:
            ValueError: If an item contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for item in opts:
            item = item.strip()
            # Split on the first "=" only, so the value itself may contain "=".
            key, separator, raw_value = item.partition("=")
            if not separator:
                raise ValueError(
                    "option '{}' is not in key=value form".format(item)
                )
            # NOTE(review): yaml.Loader is PyYAML's unrestricted loader. The
            # text comes from the operator's own command line; do not reuse
            # this helper for untrusted input.
            config[key] = yaml.load(raw_value, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Resolve the requested language and store its defaults.

        Only the first ``-l/--language`` value is used. A code that belongs
        to one of the script groups is replaced by the group name ("latin",
        "arabic", "cyrillic", "devanagari"); any other supported code is kept
        as is. The dictionary path, model save directory, train/eval label
        lists and character type in ``global_config`` are then derived from
        the resolved name.

        Args:
            type (list[str] | None): Values given to ``-l/--language``. The
                parameter name shadows the builtin; it is kept unchanged for
                compatibility.

        Returns:
            str: The resolved language or script-group name.

        Raises:
            AssertionError: If no language was given, the language is not
                supported, or the derived dictionary file does not exist
                under ``project_path``.
        """
        # Validate before indexing: when -l/--language is omitted argparse
        # hands over None, and None[0] would raise an unrelated TypeError.
        assert type, "please use -l or --language to choose language type"
        lang = type[0]
        assert lang in support_list or lang in multi_lang, (
            "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
            "please check your running command".format(multi_lang, type)
        )

        # Collapse a grouped code to the name of its script group.
        if lang in latin_lang:
            lang = "latin"
        elif lang in arabic_lang:
            lang = "arabic"
        elif lang in cyrillic_lang:
            lang = "cyrillic"
        elif lang in devanagari_lang:
            lang = "devanagari"

        global_config["Global"]["character_dict_path"] = (
            "ppocr/utils/dict/{}_dict.txt".format(lang)
        )
        global_config["Global"]["save_model_dir"] = "./json/rec_{}_lite".format(lang)
        global_config["Train"]["dataset"]["label_file_list"] = [
            "train_data/{}_train.txt".format(lang)
        ]
        global_config["Eval"]["dataset"]["label_file_list"] = [
            "train_data/{}_val.txt".format(lang)
        ]
        global_config["Global"]["character_type"] = lang

        # The default dictionary is looked up relative to the project root.
        assert os.path.isfile(
            os.path.join(project_path, global_config["Global"]["character_dict_path"])
        ), (
            "Loss default dictionary file {}_dict.txt.You can download it from "
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/"
        ).format(lang)
        return lang
def merge_config(config):
    """
    Merge config into global config.

    A key without a dot addresses a top-level entry: a dict value is merged
    into an existing entry with ``dict.update``; anything else replaces (or
    creates) the entry. A dotted key such as ``"Global.epoch_num"`` walks
    the nested dicts and assigns the value at the final component.

    Args:
        config (dict): Config to be merged.

    Returns: global config

    Raises:
        AssertionError: If the first component of a dotted key is not a
            top-level key of the global config.
        KeyError: If an intermediate component of a dotted key is missing.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
            continue

        sub_keys = key.split(".")
        assert (
            sub_keys[0] in global_config
        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
            global_config.keys(), sub_keys[0]
        )
        # Descend through every component but the last, then assign there.
        cur = global_config[sub_keys[0]]
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value
    # The docstring has always promised the merged config; actually return it.
    return global_config
def loss_file(path):
    """Assert that something exists at ``path``.

    Args:
        path (str): File or directory path to check.

    Raises:
        AssertionError: If ``path`` does not exist.
    """
    message = "There is no such file:{},Please do not forget to put in the specified file"
    assert os.path.exists(path), message.format(path)
if __name__ == "__main__":
    # Parse the command line; resolving the language also writes the
    # language-specific defaults into global_config. Then apply -o overrides.
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)

    # Remove any previous output so the config is written from scratch.
    save_file_path = "rec_{}_lite_train.yml".format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    # Optional path overrides: each is recorded in the config first and then
    # checked for existence relative to the project root.
    if FLAGS.train:
        global_config["Train"]["dataset"]["label_file_list"] = [FLAGS.train]
        loss_file(os.path.join(project_path, FLAGS.train))

    if FLAGS.val:
        global_config["Eval"]["dataset"]["label_file_list"] = [FLAGS.val]
        loss_file(os.path.join(project_path, FLAGS.val))

    if FLAGS.dict:
        global_config["Global"]["character_dict_path"] = FLAGS.dict
        loss_file(os.path.join(project_path, FLAGS.dict))

    if FLAGS.data_dir:
        for section in ("Eval", "Train"):
            global_config[section]["dataset"]["data_dir"] = FLAGS.data_dir
        loss_file(os.path.join(project_path, FLAGS.data_dir))

    with open(save_file_path, "w") as output_file:
        yaml.dump(
            dict(global_config),
            output_file,
            default_flow_style=False,
            sort_keys=False,
        )

    # Summarise the effective settings.
    logging.info("Project path is :{}".format(project_path))

    train_list = global_config["Train"]["dataset"]["label_file_list"][0]
    logging.info("Train list path set to :{}".format(train_list))

    eval_list = global_config["Eval"]["dataset"]["label_file_list"][0]
    logging.info("Eval list path set to :{}".format(eval_list))

    dataset_root = global_config["Eval"]["dataset"]["data_dir"]
    logging.info("Dataset root path set to :{}".format(dataset_root))

    dictionary = global_config["Global"]["character_dict_path"]
    logging.info("Dict path set to :{}".format(dictionary))

    logging.info(
        "Config file set to :configs/rec/multi_language/{}".format(save_file_path)
    )