Vocabulary
tensors

Inputs
tensor1: a tensor
tensor2: a tensor


Outputs
tensor3: a tensor


Word description
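Performs n-dimensional matrix multiplication on two tensors, where tensor1 has shape ... x m x n and tensor2 has shape ... x n x p. The leading dimensions (the "...") must be the same for both tensors. Each pair of corresponding m x n and n x p matrices is multiplied, and the result tensor3 has shape ... x m x p.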
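Elementwise, the operation can be written as follows. The symbols A, B and C (standing for tensor1, tensor2 and tensor3), the batch multi-index b over the leading dimensions, and the names d_1, ..., d_r for those dimensions are introduced here for illustration only; indices in the formula are 1-based.

```latex
% Shapes: A is d_1 x ... x d_r x m x n, B is d_1 x ... x d_r x n x p,
% and C = matmul(A, B) is d_1 x ... x d_r x m x p.
C_{b,i,k} = \sum_{j=1}^{n} A_{b,i,j} \, B_{b,j,k},
\qquad 1 \le i \le m, \quad 1 \le k \le p
```

For example, under this reading a { 2 3 4 } tensor times a { 2 4 5 } tensor yields a { 2 3 5 } tensor: two independent 3 x 4 by 4 x 5 products.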

Errors
Throws a shape-mismatch-error if the last two dimensions of tensor1 and tensor2 do not have the form m x n and n x p (that is, if their inner dimensions differ), or if their leading dimensions do not match.
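A sketch of the condition, reading "the leading dimensions match" as equality of the two leading shapes. This is an interpretation of the sentence above, not a transcription of check-matmul-shape; the symbols s, t, q, r and n' are introduced for illustration.

```latex
% tensor1 shape: (s_1, ..., s_r, m, n); tensor2 shape: (t_1, ..., t_q, n', p)
% matmul proceeds only if all of the following hold; otherwise shape-mismatch-error
q = r, \qquad s_l = t_l \quad (1 \le l \le r), \qquad n = n'
```

Under this reading, { 2 3 4 } with { 2 5 4 } fails because the inner dimensions 4 and 5 differ, and { 2 3 4 } with { 3 4 5 } fails because the leading dimensions 2 and 3 differ.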

Definition

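TYPED:: matmul ( tensor1: tensor tensor2: tensor -- tensor3: tensor )
    ! Throws shape-mismatch-error for incompatible shapes
    tensor1 tensor2 check-matmul-shape
    ! Split tensor1's shape into its leading dimensions (top-shape), m and n
    tensor1 shape>> unclip-last-slice :> n
    unclip-last-slice :> m :> top-shape
    tensor2 shape>> last :> p
    ! top-prod is the number of m x n by n x p products to compute
    top-shape product :> top-prod
    ! The result shape ... x m x p stays on the stack for <tensor> below;
    ! vec3 is the flat float storage for the result
    top-shape { m p } append
    dup product (float-array) :> vec3
    top-prod [
        :> i
        ! The i-th matrix of tensor1, of tensor2 and of the result
        tensor1 vec>> m n * i * m n * make-subseq
        tensor2 vec>> n p * i * n p * make-subseq
        vec3 m p * i * m p * make-subseq
        ! Choose a 2D kernel based on the inner dimension n
        m n p {
            { [ n 4 mod 0 = ] [ 2d-matmul-simd ] }
            { [ n 4 < ] [ 2d-matmul ] }
            [ i 2d-matmul-mixed ]
        } cond
    ] each-integer
    vec3 <tensor> ;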
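The loop in the definition treats each tensor's data as one flat array in which the matrices are stored back to back. Assuming make-subseq takes a sequence, a start index and a length (inferred from its arguments here, not from its documentation), batch i, for 0 <= i < top-prod, works on these blocks; d_1, ..., d_r again name the leading dimensions:

```latex
\begin{aligned}
\text{top-prod} &= d_1 \, d_2 \cdots d_r \\
\text{tensor1 block } i &: \ \text{offset } i \, m n, \ \text{length } m n \\
\text{tensor2 block } i &: \ \text{offset } i \, n p, \ \text{length } n p \\
\text{tensor3 block } i &: \ \text{offset } i \, m p, \ \text{length } m p
\end{aligned}
```

For shapes { 2 3 4 } and { 2 4 5 } (m = 3, n = 4, p = 5), top-prod is 2, and batch 1 reads tensor1 at offset 12 and tensor2 at offset 20 and writes tensor3 at offset 15. Because n = 4 satisfies n 4 mod 0 =, the cond selects 2d-matmul-simd for every batch.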