[utils] Support list of xpath in xpath_element

This commit is contained in:
Sergey M․ 2015-10-31 22:39:44 +06:00
parent 8cdb5c8453
commit 578c074575
2 changed files with 19 additions and 3 deletions

View file

@ -275,9 +275,16 @@ class TestUtil(unittest.TestCase):
p = xml.etree.ElementTree.SubElement(div, 'p')
p.text = 'Foo'
self.assertEqual(xpath_element(doc, 'div/p'), p)
self.assertEqual(xpath_element(doc, ['div/p']), p)
self.assertEqual(xpath_element(doc, ['div/bar', 'div/p']), p)
self.assertEqual(xpath_element(doc, 'div/bar', default='default'), 'default')
self.assertEqual(xpath_element(doc, ['div/bar'], default='default'), 'default')
self.assertTrue(xpath_element(doc, 'div/bar') is None)
self.assertTrue(xpath_element(doc, ['div/bar']) is None)
self.assertTrue(xpath_element(doc, ['div/bar'], 'div/baz') is None)
self.assertRaises(ExtractorError, xpath_element, doc, 'div/bar', fatal=True)
self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar'], fatal=True)
self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar', 'div/baz'], fatal=True)
def test_xpath_text(self):
testxml = '''<root>

View file

@ -178,10 +178,19 @@ def xpath_with_ns(path, ns_map):
def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
if sys.version_info < (2, 7): # Crazy 2.6
xpath = xpath.encode('ascii')
def _find_xpath(xpath):
if sys.version_info < (2, 7): # Crazy 2.6
xpath = xpath.encode('ascii')
return node.find(xpath)
if isinstance(xpath, (str, compat_str)):
n = _find_xpath(xpath)
else:
for xp in xpath:
n = _find_xpath(xp)
if n is not None:
break
n = node.find(xpath)
if n is None:
if default is not NO_DEFAULT:
return default