Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ extern "C" {
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_Q1_0 = 41,
GGML_TYPE_COUNT = 42,
GGML_TYPE_Q2_0 = 42,
GGML_TYPE_COUNT = 43,
};

// precision
Expand Down Expand Up @@ -473,6 +474,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors
GGML_FTYPE_MOSTLY_Q2_0 = 28, // except 1d tensors
};

// available tensor operations:
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ typedef sycl::half2 ggml_half2;
#define QI1_0 (QK1_0 / 32)
#define QR1_0 1

#define QI2_0 (QK2_0 / 32)
#define QR2_0 1


#define QI4_0 (QK4_0 / (4 * QR4_0))
#define QR4_0 2
Expand Down Expand Up @@ -181,6 +184,13 @@ typedef struct {
} block_q1_0;
static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding");

#define QK2_0 64
typedef struct {
ggml_half d; // delta (scale)
uint8_t qs[QK2_0 / 4]; // 2 bits per element
} block_q2_0;
static_assert(sizeof(block_q2_0) == sizeof(ggml_half) + QK2_0 / 4, "wrong q2_0 block size/padding");

#define QK4_0 32
typedef struct {
ggml_half d; // delta
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-cpu/arch-fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
Expand Down Expand Up @@ -83,6 +84,7 @@
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
Expand Down Expand Up @@ -114,6 +116,7 @@
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
Expand Down Expand Up @@ -163,6 +166,7 @@
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
Expand Down Expand Up @@ -203,6 +207,7 @@
#elif defined(__riscv)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
Expand Down Expand Up @@ -244,6 +249,7 @@
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
Expand Down Expand Up @@ -308,6 +314,7 @@
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
Expand Down
74 changes: 74 additions & 0 deletions ggml/src/ggml-cpu/arch/arm/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,80 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}

void ggml_vec_dot_q2_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK2_0;
const int nb = n / qk;

assert(n % qk == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);

const block_q2_0 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

float sumf = 0.0f;

#if defined(__ARM_NEON)
// Replicate pattern: each byte repeated 4 times
static const uint8_t tbl_idx_lo[16] = {0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3};
static const uint8_t tbl_idx_hi[16] = {4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7};
// Right-shift amounts: 0,2,4,6 repeated for each group of 4
static const int8_t shift_vals[16] = {0,-2,-4,-6, 0,-2,-4,-6, 0,-2,-4,-6, 0,-2,-4,-6};

const uint8x16_t idx_lo = vld1q_u8(tbl_idx_lo);
const uint8x16_t idx_hi = vld1q_u8(tbl_idx_hi);
const int8x16_t shifts = vld1q_s8(shift_vals);
const uint8x16_t mask2 = vdupq_n_u8(0x03);
const int8x16_t one = vdupq_n_s8(1);

float32x4_t sumv = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i++) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);

// group 64: one Q2_0 block (64 weights) maps to two Q8_0 blocks (2 * 32 = 64)
for (int k = 0; k < 2; k++) {
const block_q8_0 * GGML_RESTRICT yb = &y[i * 2 + k];
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);

// Load 8 bytes of packed 2-bit values
const uint8x8_t raw = vld1_u8(&x[i].qs[k * 8]);
const uint8x16_t raw16 = vcombine_u8(raw, raw);

// First 16 elements: replicate bytes 0-3, shift, mask, subtract 1
uint8x16_t bytes0 = vqtbl1q_u8(raw16, idx_lo);
int8x16_t qv0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshlq_u8(bytes0, shifts), mask2)),
one);

// Second 16 elements: replicate bytes 4-7, shift, mask, subtract 1
uint8x16_t bytes1 = vqtbl1q_u8(raw16, idx_hi);
int8x16_t qv1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshlq_u8(bytes1, shifts), mask2)),
one);

// Load Q8_0 values and dot product
const int8x16_t y0 = vld1q_s8(yb->qs);
const int8x16_t y1 = vld1q_s8(yb->qs + 16);

int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), qv0, y0);
int32x4_t p1 = ggml_vdotq_s32(p0, qv1, y1);

sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1);
}
}

sumf = vaddvq_f32(sumv);
#else
ggml_vec_dot_q2_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
return;
#endif

*s = sumf;
}

void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_0] = {
.from_float = quantize_row_q2_0,
.vec_dot = ggml_vec_dot_q2_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
Comment thread
khosravipasha marked this conversation as resolved.
[GGML_TYPE_Q4_0] = {
.from_float = quantize_row_q4_0,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ void ggml_compute_forward_add(
ggml_compute_forward_add_non_quantized(params, dst);
} break;
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -1115,6 +1116,7 @@ void ggml_compute_forward_add1(
}
} break;
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -1245,6 +1247,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -4415,6 +4418,7 @@ void ggml_compute_forward_out_prod(

switch (src0->type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -4691,6 +4695,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -4915,6 +4920,7 @@ void ggml_compute_forward_get_rows(

switch (src0->type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -5641,6 +5647,7 @@ void ggml_compute_forward_clamp(
} break;
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q2_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down
51 changes: 51 additions & 0 deletions ggml/src/ggml-cpu/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
quantize_row_q1_0_ref(x, y, k);
}

void quantize_row_q2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_q2_0_ref(x, y, k);
}

void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_q4_0_ref(x, y, k);
}
Expand Down Expand Up @@ -170,6 +174,53 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}

void ggml_vec_dot_q2_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK2_0;
const int nb = n / qk;

assert(n % qk == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);

const block_q2_0 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

float sumf = 0.0f;

for (int i = 0; i < nb; i++) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);

float sumi = 0.0f;

// group 64: one Q2_0 block (64 weights) maps to two Q8_0 blocks (2 * 32 = 64)
for (int k = 0; k < 2; k++) {
const block_q8_0 * GGML_RESTRICT yb = &y[i * 2 + k];
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);
int sumi_block = 0;

const uint8_t * GGML_RESTRICT qs = &x[i].qs[k * 8];
const int8_t * GGML_RESTRICT qy = yb->qs;

for (int b = 0; b < 8; ++b) {
const uint8_t byte = qs[b];
// Extract 4 two-bit values, map {0,1,2,3} -> {-1,0,1,2}
sumi_block += ((int)((byte >> 0) & 3) - 1) * qy[b*4 + 0];
sumi_block += ((int)((byte >> 2) & 3) - 1) * qy[b*4 + 1];
sumi_block += ((int)((byte >> 4) & 3) - 1) * qy[b*4 + 2];
sumi_block += ((int)((byte >> 6) & 3) - 1) * qy[b*4 + 3];
}

sumi += d1 * sumi_block;
}

sumf += d0 * sumi;
}

*s = sumf;
}

void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-cpu/quants.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern "C" {

// Quantization
void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
Expand All @@ -38,6 +39,7 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,

// Dot product
void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
Expand Down Expand Up @@ -71,6 +73,7 @@ void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
Expand Down
Loading