Utilizo Tensorflow, pero estoy escribiendo documentación para usuarios que generalmente varía según los marcos de aprendizaje profundo .
Cuando trabajo con conjuntos de datos que no se ajustan al sistema de archivos local (TB +), tomo muestras de datos de un almacén de datos remoto y escribo muestras localmente en un tfrecords
formato estándar de Tensorflow .
Durante la primera época de entrenamiento solo habré muestreado algunos valores, por lo tanto, una época de datos locales es muy pequeña, entreno con ella. En la época 2 vuelvo a examinar qué archivos de datos han sido producidos por mis subprocesos de muestreo (ahora más) y me entreno en el conjunto ampliado de archivos de datos locales para la próxima época. Repita el proceso cada época. De esta manera, construyo un caché local de muestras y puedo expulsar muestras antiguas a medida que lleno el almacenamiento local. El caché de muestras locales crece aproximadamente en el momento en que el modelo necesita más la varianza (hacia la última parte del entrenamiento).
En Python / Tensorflow es crucial que no deserialice los datos en el proceso de bucle de entrenamiento de Python porque Python GIL no puede admitir las velocidades de transferencia de datos (300-600 MB / seg, los datos son científicos sin comprimir) y, por lo tanto, el rendimiento de la GPU sufre cuando Python GIL no puede atender el ciclo de entrenamiento rápido.
Escribir las muestras en un tfrecords
archivo desde subprocesos (multiprocesamiento de python) permite que los nativos de tensorflow TFRecordsDataset
realicen deserialización fuera de Python y, por lo tanto, evitamos los problemas de Python GIL, y puedo saturar una GPU con altas tasas de datos de E / S.
Me gustaría saber cómo abordaría este problema en Pytorch. Estoy escribiendo sobre la estrategia de muestreo que se está utilizando y quiero proporcionar recomendaciones específicas a los usuarios de Tensorflow y PyTorch, pero no conozco el ecosistema de preprocesamiento de PyTorch lo suficientemente bien como para escribir con suficiente detalle.
Nota al margen: la única solución puramente basada en Python para admitir estas velocidades de transferencia de datos puede venir en Python 3.8 con memoria compartida y multiprocesamiento del Sistema V, pero aún no lo he intentado ya que el soporte no es suficiente (pronto será ) Las soluciones de multiprocesamiento existentes no son suficientes porque requieren deserialización en el proceso del ciclo de capacitación y, por lo tanto, bloquean el GIL durante la deserialización a altas tasas de E / S.
DataLoader
como en mi respuesta.Respuestas:
En realidad, puede deserializar fácilmente los datos en un subproceso mediante el uso
torch.utils.data.DataLoader
. Al establecer elnum_workers
argumento en 1 o en un valor mayor, puede generar subprocesos con sus propios intérpretes de Python y GIL.A
Dataloader
requiere untorch.utils.data.Dataset
para obtener datos. Puede que no sea un trabajo trivial implementar una subclase adecuada en su caso. En caso de que necesite recrear unaDataset
instancia para cada época, puede hacer algo como esto.o mejor
Como nota al margen, tenga en cuenta que la operación vinculada a la CPU se ve afectada por GIL en la mayoría de los casos, no la operación vinculada a E / S, es decir, funcionará
threading
para cualquier operación pesada de E / S y que ni siquiera necesitasubprocess
. Para obtener más información, consulte esta pregunta y este artículo de Wikipedia .fuente
torch.utils.data.DataLoader
coloca datos en la GPU de los subprocesos o está tratando de usar el multiprocesamiento de Python para moverlo al proceso de bucle de entrenamiento? He descubierto que solo la deserialización de un proceso a otro a velocidades de datos cercanas a 1 GB / seg es> 1 núcleo completo de trabajo, de ahí los problemas de GIL que he encontrado al intentar este enfoque en TF. Pero sitorch.utils.data.DataLoader
está moviendo datos a la GPU de una manera que no requiere la deserialización de Python, entonces todo está bien. Solo quiero confirmar esa comprensión.torch.utils.data.DataLoader
puede utilizar fácilmente 600% de CPU o más, y el proceso principal no necesita mucha potencia de CPU en la mayoría de los casos cuando el entrenamiento es principalmente computación de GPU (cuando el entrenamiento es principalmente computación de CPU, todavía no hay problema porque la operación de matriz de Pytorch puede utilizar fácilmente múltiples CPUs).