Cómo leer desde un conjunto de datos de alta IO en pytorch que crece de una época a otra

8

Utilizo Tensorflow, pero estoy escribiendo documentación para usuarios que generalmente varía según los marcos de aprendizaje profundo .

Cuando trabajo con conjuntos de datos que no se ajustan al sistema de archivos local (TB +), tomo muestras de datos de un almacén de datos remoto y escribo muestras localmente en un tfrecordsformato estándar de Tensorflow .

Durante la primera época de entrenamiento solo habré muestreado algunos valores, por lo tanto, una época de datos locales es muy pequeña, entreno con ella. En la época 2 vuelvo a examinar qué archivos de datos han sido producidos por mis subprocesos de muestreo (ahora más) y me entreno en el conjunto ampliado de archivos de datos locales para la próxima época. Repita el proceso cada época. De esta manera, construyo un caché local de muestras y puedo expulsar muestras antiguas a medida que lleno el almacenamiento local. El caché de muestras locales crece aproximadamente en el momento en que el modelo necesita más la varianza (hacia la última parte del entrenamiento).

En Python / Tensorflow es crucial que no deserialice los datos en el proceso de bucle de entrenamiento de Python porque Python GIL no puede admitir las velocidades de transferencia de datos (300-600 MB / seg, los datos son científicos sin comprimir) y, por lo tanto, el rendimiento de la GPU sufre cuando Python GIL no puede atender el ciclo de entrenamiento rápido.

Escribir las muestras en un tfrecordsarchivo desde subprocesos (multiprocesamiento de python) permite que los nativos de tensorflow TFRecordsDatasetrealicen deserialización fuera de Python y, por lo tanto, evitamos los problemas de Python GIL, y puedo saturar una GPU con altas tasas de datos de E / S.

Me gustaría saber cómo abordaría este problema en Pytorch. Estoy escribiendo sobre la estrategia de muestreo que se está utilizando y quiero proporcionar recomendaciones específicas a los usuarios de Tensorflow y PyTorch, pero no conozco el ecosistema de preprocesamiento de PyTorch lo suficientemente bien como para escribir con suficiente detalle.

Nota al margen: la única solución puramente basada en Python para admitir estas velocidades de transferencia de datos puede venir en Python 3.8 con memoria compartida y multiprocesamiento del Sistema V, pero aún no lo he intentado ya que el soporte no es suficiente (pronto será ) Las soluciones de multiprocesamiento existentes no son suficientes porque requieren deserialización en el proceso del ciclo de capacitación y, por lo tanto, bloquean el GIL durante la deserialización a altas tasas de E / S.

David Parks
fuente
2
¿Cómo sabes que las tasas de transferencia de datos sufren de Python GIL? Que yo sepa, es la operación vinculada a la CPU la que se ve afectada por GIL en la mayoría de los casos, no la operación vinculada a E / S.
bombas
En mis pruebas, solo haciendo la deserialización entre los procesos de Python a las velocidades de datos más rápidas que puedo lograr mantiene el proceso objetivo al 100% de utilización de la CPU. He intentado muchos enfoques, asincio, multiprocesamiento, incluso lecturas directas de socket. En el caso de las lecturas directas de socket, puedo obtener 4 GB / seg en los procesos, y en el momento en que incluso trato de unir cadenas binarias, caigo a 2 GB / seg, y cualquier cosa más compleja me reduce a una velocidad máxima de transferencia de 1 GB / seg. Eso es todo con el proceso de destino utilizando completamente el núcleo y, por lo tanto, bloqueando el GIL.
David Parks
Tenga en cuenta que esto no es realmente un problema con grandes conjuntos de datos comunes como imagenet porque el IO necesario para mover archivos JPEG comprimidos en redes neuronales grandes es pequeño en comparación con lo que exige la formación de datos científicos sin comprimir en redes pequeñas.
David Parks
1
una unión de cadenas se clasifica en una operación vinculada a la CPU y puede exigir fácilmente una capacidad de CPU del 100% sin utilizar la capacidad de E / S de la máquina. Por lo tanto, no es una evidencia de que un GIL restrinja el rendimiento de E / S.
bombas
2
Esas operaciones triviales no reclaman el GIL del proceso principal si los datos se cargan DataLoadercomo en mi respuesta.
bombas

Respuestas:

7

En realidad, puede deserializar fácilmente los datos en un subproceso mediante el uso torch.utils.data.DataLoader. Al establecer el num_workersargumento en 1 o en un valor mayor, puede generar subprocesos con sus propios intérpretes de Python y GIL.

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
         # loader in the main process does not claim GIL at this point

A Dataloaderrequiere un torch.utils.data.Datasetpara obtener datos. Puede que no sea un trabajo trivial implementar una subclase adecuada en su caso. En caso de que necesite recrear una Datasetinstancia para cada época, puede hacer algo como esto.

for epcoh in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training

o mejor

dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
    last_batch_idx =  (len(dset)-1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare next loader in advance to avoid blocking
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training

Como nota al margen, tenga en cuenta que la operación vinculada a la CPU se ve afectada por GIL en la mayoría de los casos, no la operación vinculada a E / S, es decir, funcionará threadingpara cualquier operación pesada de E / S y que ni siquiera necesita subprocess. Para obtener más información, consulte esta pregunta y este artículo de Wikipedia .

bombas
fuente
Solo para confirmar, ¿ torch.utils.data.DataLoadercoloca datos en la GPU de los subprocesos o está tratando de usar el multiprocesamiento de Python para moverlo al proceso de bucle de entrenamiento? He descubierto que solo la deserialización de un proceso a otro a velocidades de datos cercanas a 1 GB / seg es> 1 núcleo completo de trabajo, de ahí los problemas de GIL que he encontrado al intentar este enfoque en TF. Pero si torch.utils.data.DataLoaderestá moviendo datos a la GPU de una manera que no requiere la deserialización de Python, entonces todo está bien. Solo quiero confirmar esa comprensión.
David Parks
@DavidParks ¿Qué función específica utiliza cuando prueba la deserialización de un proceso a otro? Parece que el proceso de deserialización implica una operación vinculada a la CPU, de ahí los problemas de GIL.
bombas
He intentado multiprocesamiento (muy lento), Pipes (mejor) y lecturas de socket sin procesar (mejor). Ninguno de estos funciona cuando las tasas de E / S son una fracción sustancial de un GB / seg, solo mover esa cantidad de datos requiere más de 1 núcleo y, por lo tanto, las soluciones Python (antes de 3.8 y la memoria compartida del Sistema V) se desmoronan en Tensorflow. Es por eso que escribo en archivos tfrecords y dejo que tensorflow haga deserialización fuera de Python. TF no bloquea el Python GIL y paraleliza las operaciones, por lo que mi proceso principal utiliza un 600% de CPU, mientras que Python GIL permanece inactivo y libre para atender el ciclo de entrenamiento.
David Parks
@DavidParks Quiero decir, ¿qué tipo de función de deserialización o biblioteca utiliza? (no biblioteca de comunicación entre procesos). torch.utils.data.DataLoaderpuede utilizar fácilmente 600% de CPU o más, y el proceso principal no necesita mucha potencia de CPU en la mayoría de los casos cuando el entrenamiento es principalmente computación de GPU (cuando el entrenamiento es principalmente computación de CPU, todavía no hay problema porque la operación de matriz de Pytorch puede utilizar fácilmente múltiples CPUs).
bombas
Simplemente usando pickle para deserializar a través de procesos de python, luego una función de generador de python para alimentar muestras al ecosistema TensorFlow. Ese es el enfoque que me falla. Una vez que los datos están en el ecosistema TensorFlow, se colocan en la GPU y el entrenamiento es otra historia. TF no proporciona una forma para que los subprocesos de Python alimenten datos a TF, solo tiene unas pocas opciones, y los datos formateados tfrecords (formato Protocol Buffers) son los más lógicos. Parece que puede ser más fácil en PyTorch, por lo que tendré algunos usuarios de PyTorch aquí para validarlo.
David Parks