File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
packages: .
12
ignore-project: False
23
write-ghc-environment-files: always
34
tests: True
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
module ArrayFire.Data where
3131

3232
import Control.Exception
33-
import Control.Monad
3433
import Data.Complex
3534
import Data.Int
3635
import Data.Proxy
3736
import Data.Word
3837
import Foreign.C.Types
3938
import Foreign.ForeignPtr
4039
import Foreign.Marshal hiding (void)
40+
import Foreign.Ptr (Ptr)
4141
import Foreign.Storable
4242
import System.IO.Unsafe
4343
import Unsafe.Coerce
@@ -357,20 +357,21 @@ joinMany
357357
:: Int
358358
-> [Array a]
359359
-> Array a
360-
joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do
361-
fptrs <- forM arrays $ \(Array fptr) -> pure fptr
362-
newPtr <-
363-
alloca $ \fPtrsPtr -> do
364-
forM_ fptrs $ \fptr ->
365-
withForeignPtr fptr (poke fPtrsPtr)
366-
alloca $ \aPtr -> do
367-
zeroOutArray aPtr
368-
throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr
369-
peek aPtr
360+
joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do
361+
newPtr <- alloca $ \aPtr -> do
362+
zeroOutArray aPtr
363+
(throwAFError =<<) $
364+
withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr ->
365+
af_join_many aPtr n nArrays fPtrsPtr
366+
peek aPtr
370367
Array <$>
371368
newForeignPtr af_release_array_finalizer newPtr
369+
370+
withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b
371+
withManyForeignPtr fptrs action = go [] fptrs
372372
where
373-
nArrays = fromIntegral (length arrays)
373+
go ptrs [] = withArrayLen (reverse ptrs) action
374+
go ptrs (fptr:others) = withForeignPtr fptr $ \ptr -> go (ptr : ptrs) others
374375

375376
-- | Tiles an Array according to specified dimensions
376377
--
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,8 @@ spec =
3232
constant @(Complex Float) [1] (1.0 :+ 1.0)
3333
`shouldBe`
3434
constant @(Complex Float) [1] (1.0 :+ 1.0)
35+
it "Should join Arrays along the specified dimension" $ do
36+
join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
37+
join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
38+
joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
39+
joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]

0 commit comments

Comments
 (0)