## Recursive Matrix Multiplication in Haskell (First Step to Strassen’s Algorithm)

Strassen’s algorithm shows that two 2 × 2 matrices can be multiplied with seven multiplications rather than eight. Applied recursively to matrix blocks, this lowers the cost from O(n^3) to roughly O(n^2.81). As a first step, I wrote a quick recursive solution that still uses eight multiplications. It is not very efficient, largely because Haskell lists are lazy, boxed linked lists. Array libraries such as repa or DPH (Data Parallel Haskell) are probably better suited, and I will try to reimplement it with them. This implementation assumes square matrices whose size is already a power of 2. Please let me know if you find any bugs or have suggestions to improve this code.
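To make the "seven instead of eight" claim concrete, here is a small sketch of Strassen's formulas for a single 2 × 2 matrix of numbers, next to the ordinary eight-multiplication product. The 4-tuple representation and the names `M2`, `strassen2` and `naive2` are hypothetical choices for this illustration only; the post's own code works on lists of rows.

```haskell
-- Hypothetical representation for this sketch: a 2 x 2 matrix stored
-- row-major as a 4-tuple (a11, a12, a21, a22).
type M2 a = (a, a, a, a)

-- Strassen's scheme: seven products m1..m7, then only additions/subtractions.
strassen2 :: Num a => M2 a -> M2 a -> M2 a
strassen2 (a11, a12, a21, a22) (b11, b12, b21, b22) =
  (m1 + m4 - m5 + m7, m3 + m5, m2 + m4, m1 - m2 + m3 + m6)
  where
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

-- Textbook definition with eight multiplications, for comparison.
naive2 :: Num a => M2 a -> M2 a -> M2 a
naive2 (a11, a12, a21, a22) (b11, b12, b21, b22) =
  ( a11 * b11 + a12 * b21, a11 * b12 + a12 * b22
  , a21 * b11 + a22 * b21, a21 * b12 + a22 * b22 )
```

Loaded in GHCi, `strassen2 (1,2,3,4) (5,6,7,8)` and `naive2 (1,2,3,4) (5,6,7,8)` both give `(19,22,43,50)`. The saving only pays off when the entries are themselves large blocks, where one multiplication costs far more than an addition.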

```haskell
import Data.List ( transpose )

-- A thin wrapper around a list of rows so that it can get a Num instance.
data Matrix a = Matrix { get :: [[a]] } deriving ( Show )

instance ( Num a ) => Num ( Matrix a ) where
  -- element-wise addition and subtraction
  ( Matrix xs ) + ( Matrix ys ) = Matrix ( zipWith ( zipWith ( + ) ) xs ys )
  ( Matrix xs ) - ( Matrix ys ) = Matrix ( zipWith ( zipWith ( - ) ) xs ys )
  -- ordinary row-by-column product
  ( Matrix xs ) * ( Matrix ys ) = Matrix ( map ( \x -> map ( sum . zipWith ( * ) x ) ( transpose ys ) ) xs )
  -- not meaningful for matrices, so left undefined
  abs _ = undefined
  signum _ = undefined
  fromInteger _ = undefined

-- n is the size of the current (square) matrices, lev the recursion depth.
recurMult :: ( Num a ) => Int -> Int -> Matrix a -> Matrix a -> Matrix a
recurMult n lev xs ys
  -- not splitting matrix more than 2 levels, and never below 1 x 1
  -- (without the n < 2 check a 2 x 2 input is split into empty blocks)
  | lev >= 2 || n < 2 = xs * ys
  | otherwise = Matrix ret
  where
    n' = div n 2
    ( a , b ) = ( get xs , get ys )
    -- upper and lower halves of the rows
    ( a_u , a_l ) = splitAt n' a
    ( b_u , b_l ) = splitAt n' b
    -- left and right halves of each give the four quadrants
    ( a11 , a12 ) = ( Matrix ( map ( take n' ) a_u ) , Matrix ( map ( drop n' ) a_u ) )
    ( a21 , a22 ) = ( Matrix ( map ( take n' ) a_l ) , Matrix ( map ( drop n' ) a_l ) )
    ( b11 , b12 ) = ( Matrix ( map ( take n' ) b_u ) , Matrix ( map ( drop n' ) b_u ) )
    ( b21 , b22 ) = ( Matrix ( map ( take n' ) b_l ) , Matrix ( map ( drop n' ) b_l ) )
    -- eight recursive products, two per quadrant of the result
    Matrix c11 = recurMult n' ( lev + 1 ) a11 b11 + recurMult n' ( lev + 1 ) a12 b21
    Matrix c12 = recurMult n' ( lev + 1 ) a11 b12 + recurMult n' ( lev + 1 ) a12 b22
    Matrix c21 = recurMult n' ( lev + 1 ) a21 b11 + recurMult n' ( lev + 1 ) a22 b21
    Matrix c22 = recurMult n' ( lev + 1 ) a21 b12 + recurMult n' ( lev + 1 ) a22 b22
    -- glue the quadrants back together
    ret = ( zipWith ( ++ ) c11 c12 ) ++ ( zipWith ( ++ ) c21 c22 )

tempMult :: ( Num a ) => [ [ a ] ] -> [ [ a ] ] -> [ [ a ] ]
tempMult xs ys = get $ recurMult ( length xs ) 0 ( Matrix xs ) ( Matrix ys )
```

```
*Main> tempMult [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]] [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]]
[[10,20,30,40],[10,20,30,40],[10,20,30,40],[10,20,30,40]]
```
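The eight recursive calls in `recurMult` (two per output quadrant) are exactly what Strassen's algorithm trims to seven. Below is a minimal sketch of that next step, assuming square matrices whose size is a power of 2, as the post does. It works on plain `[[a]]` instead of the `Matrix` wrapper, and the helper names `addM`, `subM`, `naiveM`, `quadrants`, `joinQ`, `strassen` and the `cutoff` parameter are hypothetical, chosen only for this illustration.

```haskell
import Data.List (transpose)

-- Element-wise sum and difference of two matrices (lists of rows).
addM, subM :: Num a => [[a]] -> [[a]] -> [[a]]
addM = zipWith (zipWith (+))
subM = zipWith (zipWith (-))

-- Schoolbook product, used once blocks are small enough.
naiveM :: Num a => [[a]] -> [[a]] -> [[a]]
naiveM xs ys = [[sum (zipWith (*) r c) | c <- transpose ys] | r <- xs]

-- Split an n x n matrix (n even) into its four n/2 x n/2 blocks.
quadrants :: [[a]] -> ([[a]], [[a]], [[a]], [[a]])
quadrants m = (map (take h) top, map (drop h) top, map (take h) bot, map (drop h) bot)
  where
    h = length m `div` 2
    (top, bot) = splitAt h m

-- Reassemble four blocks into a single matrix.
joinQ :: [[a]] -> [[a]] -> [[a]] -> [[a]] -> [[a]]
joinQ c11 c12 c21 c22 = zipWith (++) c11 c12 ++ zipWith (++) c21 c22

-- Strassen's recursion: seven recursive products per level instead of eight.
-- Blocks of size <= cutoff (or 1 x 1) fall back to naiveM.
strassen :: Num a => Int -> [[a]] -> [[a]] -> [[a]]
strassen cutoff a b
  | n <= cutoff || n < 2 = naiveM a b
  | otherwise = joinQ c11 c12 c21 c22
  where
    n = length a
    (a11, a12, a21, a22) = quadrants a
    (b11, b12, b21, b22) = quadrants b
    mult x y = strassen cutoff x y
    m1 = mult (addM a11 a22) (addM b11 b22)
    m2 = mult (addM a21 a22) b11
    m3 = mult a11 (subM b12 b22)
    m4 = mult a22 (subM b21 b11)
    m5 = mult (addM a11 a12) b22
    m6 = mult (subM a21 a11) (addM b11 b12)
    m7 = mult (subM a12 a22) (addM b21 b22)
    c11 = addM (subM (addM m1 m4) m5) m7
    c12 = addM m3 m5
    c21 = addM m2 m4
    c22 = addM (addM (subM m1 m2) m3) m6
```

For example, with `m = replicate 4 [1,2,3,4]`, `strassen 1 m m` gives `[[10,20,30,40],[10,20,30,40],[10,20,30,40],[10,20,30,40]]`, the same as the `tempMult` session. The `cutoff` plays the role of the `lev >= 2` guard: below it, the extra additions cost more than the multiplication they save. Because the blocks are still lists, this only reduces the number of multiplications; it does nothing about the list representation itself.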
