web-dev-qa-db-fra.com

Mémorisation à Haskell?

Tous les conseils sur la façon de résoudre efficacement la fonction suivante dans Haskell, pour les grands nombres (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

J'ai vu des exemples de mémorisation dans Haskell pour résoudre les nombres de fibonacci, qui impliquaient de calculer (paresseusement) tous les nombres de fibonacci jusqu'au n requis. Mais dans ce cas, pour un n donné, il suffit de calculer très peu de résultats intermédiaires.

Merci

128
Angel de Vicente

Nous pouvons le faire très efficacement en créant une structure que nous pouvons indexer en temps sub-linéaire.

Mais d'abord,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Définissons f, mais faisons en sorte qu'il utilise la "récursion ouverte" plutôt que de s'appeler directement.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Vous pouvez obtenir un f non remémoré en utilisant fix f

Cela vous permettra de tester que f fait ce que vous voulez dire pour les petites valeurs de f en appelant, par exemple: fix f 123 = 144

Nous pourrions mémoriser cela en définissant:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Cela fonctionne passablement bien et remplace ce qui allait prendre O (n ^ 3) temps avec quelque chose qui mémorise les résultats intermédiaires.

Mais il faut toujours du temps linéaire pour simplement indexer pour trouver la réponse mémorisée pour mf. Cela signifie que des résultats comme:

*Main Data.List> faster_f 123801
248604

sont tolérables, mais le résultat n'évolue pas beaucoup mieux que cela. On peut faire mieux!

Définissons d'abord un arbre infini:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Et puis nous allons définir un moyen de l'indexer, afin que nous puissions trouver un nœud avec l'index n dans O (log n) temps à la place:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... et nous pouvons trouver un arbre plein de nombres naturels pour être pratique afin que nous n'ayons pas à jouer avec ces indices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Puisque nous pouvons indexer, vous pouvez simplement convertir un arbre en liste:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Vous pouvez vérifier le travail jusqu'à présent en vérifiant que toList nats vous donne [0..]

Maintenant,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

fonctionne comme avec la liste ci-dessus, mais au lieu de prendre du temps linéaire pour trouver chaque nœud, vous pouvez le poursuivre en temps logarithmique.

Le résultat est considérablement plus rapide:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

En fait, c'est tellement plus rapide que vous pouvez parcourir et remplacer Int par Integer ci-dessus et obtenir des réponses ridiculement grandes presque instantanément

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
246
Edward KMETT

la réponse d'Edward est un joyau si merveilleux que je l'ai dupliqué et fourni des implémentations de combinateurs memoList et memoTree qui mémorisent une fonction sous une forme récursive ouverte.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
17
Tom Ellis

Pas le moyen le plus efficace, mais mémorise:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

lors de la demande f !! 144, il est vérifié que f !! 143 existe, mais sa valeur exacte n'est pas calculée. Il est toujours défini comme un résultat inconnu d'un calcul. Les seules valeurs exactes calculées sont celles nécessaires.

Donc, initialement, en ce qui concerne le montant calculé, le programme ne sait rien.

f = .... 

Lorsque nous faisons la demande f !! 12, il commence à faire une correspondance de modèle:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, il commence à calculer

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Cela fait récursivement une autre demande sur f, donc nous calculons

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Maintenant, nous pouvons remonter un peu

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuer à ruisseler:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, nous continuons notre calcul de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, nous continuons notre calcul de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Le calcul se fait donc assez paresseusement. Le programme sait qu'une certaine valeur pour f !! 8 existe, qu'il est égal à g 8, mais il n'a aucune idée de ce que g 8 est.

12
rampion

Comme indiqué dans la réponse d'Edward Kmett, pour accélérer les choses, vous devez mettre en cache des calculs coûteux et pouvoir y accéder rapidement.

Pour garder la fonction non monadique, la solution de construction d'un arbre paresseux infini, avec une manière appropriée de l'indexer (comme indiqué dans les articles précédents) remplit cet objectif. Si vous abandonnez la nature non monadique de la fonction, vous pouvez utiliser les conteneurs associatifs standard disponibles dans Haskell en combinaison avec des monades "d'état" (comme State ou ST).

Alors que le principal inconvénient est que vous obtenez une fonction non monadique, vous n'avez plus besoin d'indexer la structure vous-même, et vous pouvez simplement utiliser des implémentations standard de conteneurs associatifs.

Pour ce faire, vous devez d'abord réécrire votre fonction pour accepter tout type de monade:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Pour vos tests, vous pouvez toujours définir une fonction qui ne fait aucune mémorisation en utilisant Data.Function.fix, bien qu'elle soit un peu plus verbeuse:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Vous pouvez ensuite utiliser State monad en combinaison avec Data.Map pour accélérer les choses:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Avec des modifications mineures, vous pouvez adapter le code pour qu'il fonctionne avec Data.HashMap à la place:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Au lieu de structures de données persistantes, vous pouvez également essayer des structures de données mutables (comme le Data.HashTable) en combinaison avec la monade ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Comparé à l'implémentation sans aucune mémorisation, n'importe laquelle de ces implémentations vous permet, pour des entrées énormes, d'obtenir des résultats en micro-secondes au lieu d'avoir à attendre plusieurs secondes.

En utilisant Criterion comme référence, j'ai pu observer que l'implémentation avec Data.HashMap fonctionnait en fait légèrement mieux (environ 20%) que celle de Data.Map et Data.HashTable pour lesquelles les timings étaient très similaires.

J'ai trouvé les résultats du benchmark un peu surprenants. Mon sentiment initial était que le HashTable surpasserait l'implémentation de HashMap car il est mutable. Il peut y avoir un défaut de performance caché dans cette dernière implémentation.

8
Quentin

Ceci est un addendum à l'excellente réponse d'Edward Kmett.

Lorsque j'ai essayé son code, les définitions de nats et index semblaient assez mystérieuses, alors j'écris une version alternative que j'ai trouvé plus facile à comprendre.

Je définis index et nats en termes de index' et nats'.

index' t n est défini sur la plage [1..]. (Rappeler que index t est défini sur la plage [0..].) Il fonctionne dans l'arborescence en traitant n comme une chaîne de bits et en parcourant les bits à l'envers. Si le bit est 1, il prend la branche de droite. Si le bit est 0, il prend la branche de gauche. Il s'arrête lorsqu'il atteint le dernier bit (qui doit être un 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Tout comme nats est défini pour index de sorte que index nats n == n est toujours vrai, nats' est défini pour index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Maintenant, nats et index sont simplement nats' et index' mais avec des valeurs décalées de 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
8
Pitarou

Quelques années plus tard, j'ai regardé cela et j'ai réalisé qu'il y avait un moyen simple de le mémoriser en temps linéaire en utilisant zipWith et une fonction d'aide:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate a la propriété pratique que dilate n xs !! i == xs !! div i n.

Donc, en supposant qu'on nous donne f (0), cela simplifie le calcul pour

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Ressemblant beaucoup à notre description originale du problème et donnant une solution linéaire (sum $ take n fs prendra O (n)).

4
rampion

Encore un addendum à la réponse d'Edward Kmett: un exemple autonome:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Utilisez-le comme suit pour mémoriser une fonction avec un seul argument entier (par exemple fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Seules les valeurs des arguments non négatifs seront mises en cache.

Pour mettre également en cache les valeurs des arguments négatifs, utilisez memoInt, défini comme suit:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Pour mettre en cache les valeurs des fonctions avec deux arguments entiers, utilisez memoIntInt, défini comme suit:

memoIntInt f = memoInt (\n -> memoInt (f n))
2
Neal Young

Une solution sans indexation, et non basée sur celle d'Edward KMETT.

Je factorise les sous-arbres communs vers un parent commun (f(n/4) est partagé entre f(n/2) et f(n/4), et f(n/6) est partagé entre f(2) et f(3)). En les enregistrant comme une seule variable dans le parent, le calcul du sous-arbre est effectué une fois.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

Le code ne s'étend pas facilement à une fonction de mémorisation générale (au moins, je ne saurais pas comment le faire), et vous devez vraiment réfléchir à la façon dont les sous-problèmes se chevauchent, mais la stratégie devrait fonctionner pour les paramètres généraux multiples non entiers. (Je l'ai pensé pour deux paramètres de chaîne.)

Le mémo est supprimé après chaque calcul. (Encore une fois, je pensais à deux paramètres de chaîne.)

Je ne sais pas si c'est plus efficace que les autres réponses. Chaque recherche n'est techniquement qu'une ou deux étapes ("Regardez votre enfant ou l'enfant de votre enfant"), mais il peut y avoir beaucoup de mémoire supplémentaire.

Edit: Cette solution n'est pas encore correcte. Le partage est incomplet.

Edit: Il devrait partager correctement les sous-enfants maintenant, mais j'ai réalisé que ce problème a beaucoup de partage non trivial: n/2/2/2 et n/3/3 pourrait être le même. Le problème ne convient pas à ma stratégie.

2
leewz