44import { MatMul } from '../../../ops/matmul' ;
55import { Tensor } from '../../../tensor' ;
66import { BroadcastUtil } from '../../../util' ;
7+ import { getGlsl } from '../glsl-source' ;
78import { WebGLInferenceHandler } from '../inference-handler' ;
89import { ProgramInfo , RunData , WebGLOperator } from '../types' ;
10+ import { getCoordsDataType } from '../utils' ;
911
1012export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
/**
 * Runs this operator by handing it, together with its inputs, to the
 * inference handler.
 *
 * @param inferenceHandler Handler whose `run` method is invoked.
 * @param inputs Input tensors forwarded unchanged to the handler.
 * @returns The tensors returned by `inferenceHandler.run`.
 */
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
  return inferenceHandler.run(this, inputs);
}
1416 createProgramInfo ( handler : WebGLInferenceHandler , inputs : Tensor [ ] ) : ProgramInfo {
1517 const hasBias = inputs . length > 2 ;
16- const processBias = hasBias ? `value += vec4(getBias(a[0]*2).xx, getBias(a[0]*2).yy);` : `` ;
18+ const processBias = hasBias ? 'result += getBiasAtOutCoords();' : '' ;
1719 const aShape = inputs [ 0 ] . dims ;
1820 const bShape = inputs [ 1 ] . dims ;
1921 const outputShape = BroadcastUtil . calcShape ( aShape , bShape , true ) ;
2022
2123 if ( ! outputShape ) {
2224 throw new Error ( 'Can\'t use matmul on the given tensors' ) ;
2325 }
24- const rank = outputShape . length ;
26+
27+ const sharedDim = aShape [ aShape . length - 1 ] ;
28+ const sharedDimIndex = Math . ceil ( sharedDim / 2 ) ;
29+
2530 const aRank = aShape . length ;
2631 const bRank = bShape . length ;
27- const sharedDim = aShape [ aShape . length - 1 ] ;
28- // TODO:fix broadcasting
32+
33+ const glsl = getGlsl ( handler . session . backend . glContext . version ) ;
34+ const coordsDataType = getCoordsDataType ( outputShape . length ) ;
35+ const allGlChannels = [ 'x' , 'y' , 'z' , 'w' , 'u' , 'v' ] ;
36+
2937 const shaderSource = `
30- vec4 process(int indices[${ rank } ]) {
31- int a[${ aRank } ];
32- int b[${ bRank } ];
33- bcastMatmulIndices_A(indices, a);
34- bcastMatmulIndices_B(indices, b);
38+ void main() {
39+ ${ coordsDataType } rc = getOutputCoords();
40+
41+ vec4 result = vec4(0);
42+
43+ for (int i = 0; i < ${ sharedDimIndex } ; i++) {
44+ vec4 a = getA(${ getA ( allGlChannels , aRank ) } );
45+ vec4 b = getB(${ getB ( allGlChannels , bRank ) } );
46+
47+ result += (a.rrbb * b.rgrg);
48+ result += (a.ggaa * b.baba);
49+ }
50+
51+ ${ processBias }
52+
53+ ${ glsl . output } = result;
54+ }` ;
3555
36- vec4 value;
37- for (int k=0; k<((${ sharedDim } +1)/2); ++k) {
38- a[${ aRank - 1 } ] = k;
39- b[${ bRank - 2 } ] = k;
40- value += ${ getA ( aRank ) } .rrbb * ${ getB ( bRank ) } .rgrg;
41- value += ${ getA ( aRank ) } .ggaa * ${ getB ( bRank ) } .baba;
42- }
43- ${ processBias }
44- return value;
45- }` ;
4656 return {
4757 inputLayouts : inputs . map ( ( t , i ) => handler . getOrCreateTextureLayout ( t , 4 , true , inputs [ i ] . dims , true ) ) ,
4858 outputLayout :
4959 handler . createTextureLayoutFromShape ( outputShape , 4 , outputShape , { isPacked : true , reverseWH : true } ) ,
5060 samplers : hasBias ? [ 'A' , 'B' , 'Bias' ] : [ 'A' , 'B' ] ,
5161 shaderSource,
62+ hasMain : true ,
5263 expectPackedInputs : true ,
5364 expectPackedOutputs : true ,
5465 } ;
@@ -64,22 +75,22 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
6475 }
6576}
6677
/**
 * Builds the comma-separated argument list placed inside the shader's
 * `getA(...)` call.
 *
 * The first `rank - 1` arguments are components of the shader variable `rc`,
 * named by the leading entries of `allGlChannels`; the final argument is the
 * literal `i<<1`, i.e. twice the shader loop counter `i`.
 *
 * NOTE(review): components are always taken from the front of
 * `allGlChannels` according to this operand's own rank. If the operand rank
 * can differ from the rank of `rc` (broadcasting), confirm that the caller
 * accounts for that.
 *
 * @param allGlChannels Component names used to address `rc`, in order.
 * @param rank Rank of the operand sampled through `getA`.
 * @returns The argument list, e.g. `rc.x, rc.y, i<<1` for rank 3.
 * @throws Error when `rank` is below 2 or needs more component names than
 *     `allGlChannels` provides.
 */
function getA(allGlChannels: string[], rank: number): string {
  // Fail early: an out-of-range index would otherwise emit `rc.undefined`
  // into the generated shader source.
  if (rank < 2 || rank - 2 >= allGlChannels.length) {
    throw new Error(`getA: unsupported rank ${rank} for ${allGlChannels.length} channel names`);
  }
  // Components 0 .. rank-2 of rc, followed by the shared-dimension index.
  const coords = allGlChannels.slice(0, rank - 1).map((channel) => `rc.${channel}`);
  return [...coords, 'i<<1'].join(', ');
}
7687
/**
 * Builds the comma-separated argument list placed inside the shader's
 * `getB(...)` call.
 *
 * The first `rank - 2` arguments are components of the shader variable `rc`,
 * named by the leading entries of `allGlChannels`; they are followed by the
 * literal `i<<1` (twice the shader loop counter `i`) and finally by the `rc`
 * component named by `allGlChannels[rank - 1]`.
 *
 * NOTE(review): components are selected by this operand's own rank. If the
 * operand rank can differ from the rank of `rc` (broadcasting), confirm that
 * the caller accounts for that.
 *
 * @param allGlChannels Component names used to address `rc`, in order.
 * @param rank Rank of the operand sampled through `getB`.
 * @returns The argument list, e.g. `rc.x, i<<1, rc.z` for rank 3.
 * @throws Error when `rank` is below 2 or exceeds the number of entries in
 *     `allGlChannels`.
 */
function getB(allGlChannels: string[], rank: number): string {
  // Fail early: an out-of-range index would otherwise emit `rc.undefined`
  // into the generated shader source.
  if (rank < 2 || rank > allGlChannels.length) {
    throw new Error(`getB: unsupported rank ${rank} for ${allGlChannels.length} channel names`);
  }
  // Leading components of rc, then the shared-dimension index, then the
  // component for the last dimension.
  const leading = allGlChannels.slice(0, rank - 2).map((channel) => `rc.${channel}`);
  return [...leading, 'i<<1', `rc.${allGlChannels[rank - 1]}`].join(', ');
}
0 commit comments