¿Memoración en Haskell?

136

Cualquier indicador sobre cómo resolver eficientemente la siguiente función en Haskell, para grandes números (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

He visto ejemplos de memorización en Haskell para resolver números de Fibonacci, lo que implicaba calcular (perezosamente) todos los números de Fibonacci hasta el n requerido. Pero en este caso, para un n dado, solo necesitamos calcular muy pocos resultados intermedios.

Gracias

Angel de Vicente
fuente
110
Solo en el sentido de que es un trabajo que estoy haciendo en casa :-)
Angel de Vicente

Respuestas:

256

Podemos hacer esto de manera muy eficiente haciendo una estructura que podamos indexar en tiempo sub-lineal.

Pero primero,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Definamos f, pero hagamos que use 'recursión abierta' en lugar de llamarse a sí mismo directamente.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Puede obtener un no conmemorado fusandofix f

Esto le permitirá probar que fhace lo que quiere decir con valores pequeños fllamando, por ejemplo:fix f 123 = 144

Podríamos memorizar esto definiendo:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Eso funciona aceptablemente bien y reemplaza lo que iba a tomar tiempo O (n ^ 3) con algo que memoriza los resultados intermedios.

Pero todavía lleva tiempo lineal solo indexar para encontrar la respuesta memorable mf. Esto significa que resultados como:

*Main Data.List> faster_f 123801
248604

son tolerables, pero el resultado no escala mucho mejor que eso. ¡Podemos hacerlo mejor!

Primero, definamos un árbol infinito:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Y luego definiremos una forma de indexarlo, para que podamos encontrar un nodo con índice nen tiempo O (log n) en su lugar:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... y podemos encontrar un árbol lleno de números naturales que sea conveniente para no tener que jugar con esos índices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Como podemos indexar, puede convertir un árbol en una lista:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Puede verificar el trabajo hasta ahora verificando que toList natsle da[0..]

Ahora,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

funciona igual que con la lista anterior, pero en lugar de tomar tiempo lineal para encontrar cada nodo, puede perseguirlo en tiempo logarítmico.

El resultado es considerablemente más rápido:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

De hecho, es mucho más rápido que puede pasar y reemplazar Intcon el Integeranterior y obtener respuestas ridículamente grandes casi instantáneamente

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Edward KMETT
fuente
3
Probé este código y, curiosamente, f_faster parecía ser más lento que f. Supongo que esas referencias de la lista realmente retrasaron las cosas. La definición de nats e índice me pareció bastante misteriosa, así que agregué mi propia respuesta que podría aclarar las cosas.
Pitarou
55
El caso de la lista infinita tiene que tratar con una lista vinculada de 111111111 elementos de largo. El caso del árbol trata con log n * el número de nodos alcanzados.
Edward KMETT
2
es decir, la versión de la lista debe crear thunks para todos los nodos de la lista, mientras que la versión del árbol evita crear muchos de ellos.
Tom Ellis
77
Sé que esta es una publicación bastante antigua, pero no f_treedebería definirse en una wherecláusula para evitar guardar rutas innecesarias en el árbol a través de las llamadas.
dfeuer
17
La razón para rellenarlo en un CAF fue que podía obtener una memorización entre llamadas. Si tuviera una llamada costosa que estaba recordando, entonces probablemente la dejaría en un CAF, de ahí la técnica que se muestra aquí. En una aplicación real, existe una compensación entre los beneficios y los costos de la memorización permanente, por supuesto. Sin embargo, dado que la pregunta era sobre cómo lograr la memorización, creo que sería engañoso responder con una técnica que evite deliberadamente la memorización a través de llamadas, y si nada más, este comentario aquí señalará a la gente el hecho de que hay sutilezas. ;)
Edward KMETT
17

La respuesta de Edward es una joya tan maravillosa que la he duplicado y proporcioné implementaciones memoListy memoTreecombinadores que memorizan una función en forma recursiva abierta.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
Tom Ellis
fuente
12

No es la forma más eficiente, pero recuerda:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

cuando se solicita f !! 144, se verifica que f !! 143existe, pero no se calcula su valor exacto. Todavía se establece como un resultado desconocido de un cálculo. Los únicos valores exactos calculados son los necesarios.

Inicialmente, en cuanto a cuánto se ha calculado, el programa no sabe nada.

f = .... 

Cuando hacemos la solicitud f !! 12, comienza a hacer alguna coincidencia de patrones:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora comienza a calcular

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Esto recursivamente hace otra demanda en f, entonces calculamos

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Ahora podemos hacer una copia de seguridad de algunos

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuando a gotear:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Lo que significa que el programa ahora sabe:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Ahora continuamos con nuestro cálculo de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Lo que significa que el programa ahora sabe:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Por lo tanto, el cálculo se realiza con bastante pereza. El programa sabe que f !! 8existe algún valor para , que es igual a g 8, pero no tiene idea de qué g 8es.

rampion
fuente
Gracias por esto. ¿Cómo crearía y usaría un espacio de solución bidimensional? ¿Sería una lista de listas? yg n m = (something with) f!!a!!b
vikingsteve
1
Claro que podrías. Sin embargo, para una solución real, probablemente usaría una biblioteca de memorización, como memocombinators
rampion
Es O (n ^ 2) por desgracia.
Qumeric el
8

Esta es una adición a la excelente respuesta de Edward Kmett.

Cuando probé su código, las definiciones de natsy indexparecían bastante misteriosas, así que escribí una versión alternativa que me resultó más fácil de entender.

Defino indexy natsen términos de index'y nats'.

index' t nse define sobre el rango [1..]. (Recuerde que index tse define sobre el rango [0..]). Funciona busca en el árbol al tratarlo ncomo una cadena de bits y leer los bits a la inversa. Si el bit es 1, toma la rama de la derecha. Si el bit es 0, toma la rama de la izquierda. Se detiene cuando alcanza el último bit (que debe ser a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Así como natsse define para, indexeso index nats n == nsiempre es cierto, nats'se define para index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Ahora, natsy indexson simples nats'y index'pero con los valores desplazados por 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Pitarou
fuente
Gracias. Estoy memorizando una función multivariante, y esto realmente me ayudó a determinar qué índices y nats realmente estaban haciendo.
Kittsil
8

Como se indica en la respuesta de Edward Kmett, para acelerar las cosas, debe almacenar en caché los costosos cálculos y poder acceder a ellos rápidamente.

Para mantener la función no monádica, la solución de construir un árbol perezoso infinito, con una forma adecuada de indexarlo (como se muestra en publicaciones anteriores) cumple ese objetivo. Si renuncia a la naturaleza no monádica de la función, puede usar los contenedores asociativos estándar disponibles en Haskell en combinación con mónadas "estatales" (como State o ST).

Si bien el inconveniente principal es que obtiene una función no monádica, ya no tiene que indexar la estructura usted mismo, y solo puede usar implementaciones estándar de contenedores asociativos.

Para hacerlo, primero debe volver a escribir su función para aceptar cualquier tipo de mónada:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Para sus pruebas, aún puede definir una función que no memorice usando Data.Function.fix, aunque es un poco más detallado:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Luego puede usar State Mónada en combinación con Data.Map para acelerar las cosas:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Con cambios menores, puede adaptar el código para que funcione con Data.HashMap en su lugar:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

En lugar de estructuras de datos persistentes, también puede probar estructuras de datos mutables (como Data.HashTable) en combinación con la mónada ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

En comparación con la implementación sin ninguna memorización, cualquiera de estas implementaciones le permite, para grandes entradas, obtener resultados en microsegundos en lugar de tener que esperar varios segundos.

Utilizando Criterion como punto de referencia, pude observar que la implementación con Data.HashMap en realidad funcionó un poco mejor (alrededor del 20%) que Data.Map y Data.HashTable para los cuales los tiempos fueron muy similares.

Los resultados del punto de referencia me parecieron un poco sorprendentes. Mi sensación inicial fue que HashTable superaría la implementación de HashMap porque es mutable. Puede haber algún defecto de rendimiento oculto en esta última implementación.

Quentin
fuente
2
GHC hace un muy buen trabajo optimizando alrededor de estructuras inmutables. La intuición de C no siempre funciona.
John Tyree
3

Un par de años después, miré esto y me di cuenta de que hay una manera simple de memorizar esto en tiempo lineal usando zipWithuna función auxiliar:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilatetiene la práctica propiedad que dilate n xs !! i == xs !! div i n.

Entonces, suponiendo que se nos dé f (0), esto simplifica el cálculo a

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Se parece mucho a nuestra descripción original del problema y da una solución lineal ( sum $ take n fstomará O (n)).

rampion
fuente
2
entonces es una solución generativa (corecursive?) o dinámica. Tomando O (1) tiempo por cada valor generado, como lo está haciendo Fibonacci. ¡Excelente! Y la solución de EKMETT es como el gran Fibonacci logarítmico, llegando a los números grandes mucho más rápido, saltando gran parte de los intermedios. ¿Es esto correcto?
Will Ness
o tal vez esté más cerca de los números de Hamming, con los tres indicadores de retroceso en la secuencia que se está produciendo, y las diferentes velocidades para cada uno de ellos avanzando a lo largo de la misma. realmente bonito.
Will Ness
2

Otra adición a la respuesta de Edward Kmett: un ejemplo autónomo:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Úselo de la siguiente manera para memorizar una función con un solo argumento entero (por ejemplo, fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Solo se almacenarán en caché los valores para argumentos no negativos.

Para almacenar también valores de caché para argumentos negativos, use memoInt, definido de la siguiente manera:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Para almacenar en caché los valores de las funciones con dos argumentos de uso entero memoIntInt, definidos de la siguiente manera:

memoIntInt f = memoInt (\n -> memoInt (f n))
Neal Young
fuente
2

Una solución sin indexación, y no basada en Edward KMETT.

Factorizo ​​subárboles comunes a un padre común ( f(n/4)se comparte entre f(n/2)y f(n/4), y f(n/6)se comparte entre f(2)y f(3)). Al guardarlos como una sola variable en el padre, el cálculo del subárbol se realiza una vez.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

El código no se extiende fácilmente a una función de memorización general (al menos, no sabría cómo hacerlo), y realmente tiene que pensar cómo se superponen los subproblemas, pero la estrategia debería funcionar para múltiples parámetros generales no enteros . (Lo pensé para dos parámetros de cadena).

El memo se descarta después de cada cálculo. (Nuevamente, estaba pensando en dos parámetros de cadena).

No sé si esto es más eficiente que las otras respuestas. Cada búsqueda es técnicamente solo uno o dos pasos ("Mire a su hijo o al hijo de su hijo"), pero puede haber mucho uso de memoria adicional.

Editar: esta solución aún no es correcta. El intercambio es incompleto.

Editar: Debería estar compartiendo los hijos secundarios correctamente ahora, pero me di cuenta de que este problema tiene una gran cantidad de compartir no trivial: n/2/2/2y n/3/3podría ser el mismo. El problema no encaja bien con mi estrategia.

leewz
fuente