from xml.etree.ElementTree import ElementTree
from os import walk, path
def read_xml(in_path):
    """Parse the XML document stored at *in_path*.

    Args:
        in_path: path (or file object) of the XML document to load.

    Returns:
        The populated ``ElementTree`` (not just its root element).
    """
    document = ElementTree()
    # parse() returns the root element; we discard it and hand back the tree.
    document.parse(source=in_path)
    return document
def write_xml(tree, out_path):
    """Serialise *tree* to *out_path* as UTF-8, preceded by an XML declaration.

    Args:
        tree: the ``ElementTree`` to write.
        out_path: destination file path (or writable file object).
    """
    write_options = {"encoding": "utf-8", "xml_declaration": True}
    tree.write(out_path, **write_options)
def find_nodes(tree, path):
    """Return a list of all elements of *tree* matching ElementPath *path*.

    Args:
        tree: an ``ElementTree`` (or ``Element``) to search.
        path: an ElementPath expression, e.g. ``"object"`` or ``"./*"``.

    Returns:
        A list of matching elements in document order (empty if none match).
    """
    return list(tree.iterfind(path))
def del_node_by_target_classes(nodelist, target_classes_lower, tree_root):
    """Remove ``<object>`` elements whose class is not a target class.

    For each ``<object>`` element in *nodelist*, the class name is taken from
    its ``<name>`` child, or from its first child when no ``<name>`` child
    exists (the only lookup earlier versions performed). Objects whose
    lower-cased class name is in *target_classes_lower* are kept, and their
    class text is rewritten in lower case. All other objects are removed from
    *tree_root*, including objects with no children at all. Elements with any
    other tag are left untouched.

    Args:
        nodelist: iterable of elements to inspect. Any ``<object>`` that gets
            removed must be a direct child of *tree_root*.
        target_classes_lower: container (list/set) of lower-case class names
            to keep.
        tree_root: element from which rejected ``<object>`` nodes are removed.

    Returns:
        None. *tree_root* is modified in place.
    """
    # Snapshot first so removals are safe even if nodelist is tree_root itself.
    for node in list(nodelist):
        if node.tag != "object":
            continue
        # Element.getchildren() was removed in Python 3.9. Look up the class
        # element by tag rather than assuming it is always the first child.
        name_node = node.find("name")
        if name_node is None:
            children = list(node)
            name_node = children[0] if children else None
        if name_node is None:
            # An <object> without any children carries no class label.
            tree_root.remove(node)
            continue
        # Treat a missing text value as an empty (non-target) class name.
        class_name = (name_node.text or "").lower()
        if class_name in target_classes_lower:
            name_node.text = class_name
        else:
            tree_root.remove(node)
def get_fileNames(rootdir):
    """Collect every ``.xml`` file under *rootdir*, searching recursively.

    Directories and files are visited in sorted order, so the result is
    deterministic across runs and filesystems. ``os.walk`` on its own yields
    entries in arbitrary order, which matters when outputs are numbered by
    their position in the returned list.

    Args:
        rootdir: directory to search.

    Returns:
        A ``(data_path, prefixs)`` tuple of parallel lists: the full path of
        each ``.xml`` file, and that file's name without the extension.
    """
    data_path = []
    prefixs = []
    for root, dirs, files in walk(rootdir, topdown=True):
        # With topdown=True, sorting dirs in place also fixes the descent order.
        dirs.sort()
        for name in sorted(files):
            pre, ending = path.splitext(name)
            if ending != ".xml":
                continue
            data_path.append(path.join(root, name))
            prefixs.append(pre)
    return data_path, prefixs
if __name__ == "__main__":
    from os import makedirs

    in_dir = "/home/yasin/old_labels/"
    out_dir = "/home/yasin/new_labels/"
    # Class names to keep; matching against each <object>'s class is
    # case-insensitive because both sides are lower-cased.
    target_classes = ["PEOPLE", "CAT"]
    target_classes_lower = {name.lower() for name in target_classes}

    # All input xml paths; the prefixes (names without extension) are unused here.
    paths_xml, prefixs = get_fileNames(in_dir)
    # tree.write() does not create directories, so make sure the output dir exists.
    makedirs(out_dir, exist_ok=True)
    for i, xml_path in enumerate(paths_xml):
        tree = read_xml(xml_path)
        tree_root = tree.getroot()
        # Direct children of the root ("./" is ElementPath shorthand for "./*").
        top_level_nodes = find_nodes(tree, "./*")
        # Drop non-target <object> entries and lower-case the kept class names.
        del_node_by_target_classes(top_level_nodes, target_classes_lower, tree_root)
        # Outputs are renumbered by position: 000000.xml, 000001.xml, ...
        write_xml(tree, path.join(out_dir, "{:06d}.xml".format(i)))