tests: gadget0: Added testing for the BOS descriptor retrieval and invalid BOS request handling

This commit is contained in:
dragonmux
2022-08-16 18:28:12 +01:00
committed by Piotr Esden-Tempski
parent 274958572a
commit 2f1d5e38ab

View File

@@ -13,6 +13,7 @@ import array
import datetime import datetime
import random import random
import usb.core import usb.core
import usb.control
import usb.util as uu import usb.util as uu
import random import random
import sys import sys
@@ -41,6 +42,13 @@ GZ_REQ_READ_LOOPBACK_BUFFER=11
GZ_REQ_INTEL_WRITE=0x5b GZ_REQ_INTEL_WRITE=0x5b
GZ_REQ_INTEL_READ=0x5c GZ_REQ_INTEL_READ=0x5c
DESC_TYPE_BOS = 0x0F
DESC_TYPE_DEVICE_CAPABILITY = 0x10
DEVCAP_TYPE_PLATFORM = 5
MICROSOFT_WINDOWS_VERSION_WINBLUE = 0x06030000
class find_by_serial(object): class find_by_serial(object):
def __init__(self, serial): def __init__(self, serial):
self._serial = serial self._serial = serial
@@ -260,7 +268,6 @@ class TestConfigLoopBack(unittest.TestCase):
expected = array.array('B', [x for x in data]) expected = array.array('B', [x for x in data])
self.assertEqual(expected, read, "should have read back what we wrote") self.assertEqual(expected, read, "should have read back what we wrote")
def test_simple_loop(self): def test_simple_loop(self):
"""Plain simple loopback, does it work at all""" """Plain simple loopback, does it work at all"""
eout = self.eps_out[0] eout = self.eps_out[0]
@@ -506,6 +513,68 @@ class TestUnaligned(unittest.TestCase):
self.do_readwrite() self.do_readwrite()
class TestBOSDescriptor(unittest.TestCase):
"""
Make sure the stack correctly handles a request for the BOS descriptor, and discards invalid BOS requests
"""
def setUp(self):
self.dev : usb.core.Device = usb.core.find(idVendor=VENDOR_ID, idProduct=PRODUCT_ID, custom_match=find_by_serial(DUT_SERIAL))
self.assertIsNotNone(self.dev, "Couldn't find locm3 gadget0 device")
def tearDown(self):
uu.dispose_resources(self.dev)
def test_partial_request(self):
bos : bytes = usb.control.get_descriptor(self.dev, 5, DESC_TYPE_BOS, 0).tobytes()
self.assertEqual(len(bos), 5)
# Check the BOS descriptor returned is valid
self.assertEqual(bos[0], 5)
self.assertEqual(bos[1], DESC_TYPE_BOS)
self.assertEqual(bos[4], 1)
bos_total_length = bos[2] + (bos[3] << 8)
self.assertNotEqual(bos_total_length, 0)
self.assertEqual(bos_total_length, 33)
def test_complete_request(self):
bos : bytes = usb.control.get_descriptor(self.dev, 33, DESC_TYPE_BOS, 0).tobytes()
self.assertEqual(len(bos), 33)
# Check the BOS descriptor returned is valid
self.assertEqual(bos[0], 5)
self.assertEqual(bos[1], DESC_TYPE_BOS)
self.assertEqual(bos[4], 1)
# This BOS descriptor is followed by a platform capability descriptor
platform_capability = bos[5:]
self.assertEqual(platform_capability[0], 28)
self.assertEqual(platform_capability[1], DESC_TYPE_DEVICE_CAPABILITY)
self.assertEqual(platform_capability[2], DEVCAP_TYPE_PLATFORM)
self.assertEqual(platform_capability[3], 0)
# Make sure it's a Microsoft OS 2.0 Descriptors platform capability
self.assertEqual(platform_capability[4:20].hex(), 'df60ddd88945c74c9cd2659d9e648a9f')
# The PCD is followed by a Microsoft OS Descriptor Set Info descriptor
descriptor_set_info = platform_capability[20:]
windows_version = (descriptor_set_info[0] | (descriptor_set_info[1] << 8) |
(descriptor_set_info[2] << 16) | (descriptor_set_info[3] << 24))
self.assertEqual(windows_version, MICROSOFT_WINDOWS_VERSION_WINBLUE)
total_length = descriptor_set_info[4] | (descriptor_set_info[5] << 8)
self.assertNotEqual(total_length, 0)
self.assertEqual(descriptor_set_info[6], 1)
self.assertEqual(descriptor_set_info[7], 0)
def test_invalid_request(self):
try:
usb.control.get_descriptor(self.dev, 5, DESC_TYPE_BOS, 1)
self.fail("get_descriptor() for an invalid BOS request suceeded")
except usb.core.USBError as e:
# Make sure we got a pipe error (EP0 STALL)
# Why libusb returns stalls as this and makes them fatal is.. a subject for a different place..
self.assertEqual(e.errno, 32)
def run_ci_test(dut): def run_ci_test(dut):
# Avoids the import for non-CI users! # Avoids the import for non-CI users!
import xmlrunner import xmlrunner
@@ -542,4 +611,3 @@ if __name__ == "__main__":
print("Detected %s on bus:port-address: %s:%s-%s" % (DUT_SERIAL, dev.bus, '.'.join(map(str,dev.port_numbers)), dev.address)) print("Detected %s on bus:port-address: %s:%s-%s" % (DUT_SERIAL, dev.bus, '.'.join(map(str,dev.port_numbers)), dev.address))
else: else:
runner(DUT_SERIAL) runner(DUT_SERIAL)