Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [

```mlir
vector-type ::= !cir.vector<size x element-type>
size ::= (decimal-literal | `[` decimal-literal `]`)
element-type ::= float-type | integer-type | pointer-type
```

Expand All @@ -442,6 +443,13 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
!cir.vector<4 x !cir.int<u, 8>>
!cir.vector<2 x !cir.float>
```

Scalable vectors are indicated by enclosing size in square brackets.

Example:
```mlir
!cir.vector<[4] x !cir.int<u, 8>>
```
}];

let parameters = (ins
Expand All @@ -450,10 +458,6 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [
OptionalParameter<"bool">:$is_scalable
);

let assemblyFormat = [{
`<` $size `x` $element_type `>`
}];

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$element_type, "uint64_t":$size, CArg<"bool",
Expand All @@ -471,6 +475,7 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", [

let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
69 changes: 69 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,75 @@ mlir::LogicalResult cir::VectorType::verify(
return success();
}

mlir::Type cir::VectorType::parse(::mlir::AsmParser &odsParser) {

llvm::SMLoc odsLoc = odsParser.getCurrentLocation();
mlir::Builder odsBuilder(odsParser.getContext());
mlir::FailureOr<::mlir::Type> elementType;
mlir::FailureOr<uint64_t> size;
bool isScalabe = false;

// Parse literal '<'
if (odsParser.parseLess())
return {};

// Parse literal '[', if present, and set the scalability flag accordingly
if (odsParser.parseOptionalLSquare().succeeded())
isScalabe = true;

// Parse variable 'size'
size = mlir::FieldParser<uint64_t>::parse(odsParser);
if (mlir::failed(size)) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse CIR_VectorType parameter 'size' which "
"is to be a `uint64_t`");
return {};
}

// Parse literal ']', which is expected when dealing with scalable
// dim sizes
if (isScalabe && odsParser.parseRSquare().failed()) {
odsParser.emitError(odsParser.getCurrentLocation(),
"missing closing `]` for scalable dim size");
return {};
}

// Parse literal 'x'
if (odsParser.parseKeyword("x"))
return {};

// Parse variable 'elementType'
elementType = mlir::FieldParser<::mlir::Type>::parse(odsParser);
if (mlir::failed(elementType)) {
odsParser.emitError(odsParser.getCurrentLocation(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add an invalid vector cir test to make sure errors are emitted correctly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be different in CIR, but in MLIR we don't really test the parser and tests in invalid.mlir (and similar) are reserved for verification errors.

I am happy to add a test if you think that that would be helpful, but we’d probably want to add a dedicated file for parser errors - perhaps one already exists? I didn’t find any.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually, for the handwritten parser or verifier, we have a invalid-<>.cir file to test it, for example, clang/test/CIR/IR/invalid-vector.cir, I think we can add a small test for scalable vector type syntax. What do you think? @andykaylor

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, Only this error need to be covered "missing closing ] for scalable dim size", the other already covered in this test :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test in this commit.

IMO, it's not great :/ (suggestions for improvement are welcome) .Testing !cir.vector<[1 x !s32i> (instead of !cir.vector<[1 x] !s32i>) would be better, but the former is captured by Parser::parseDialectSymbolBody with:

/llvm-project/clang/test/CIR/IR/invalid-vector.cir:17:30: error: unbalanced '[' character in pretty dialect name
  %3 = cir.alloca !cir.vector<[1 x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>
                             ^

That error is hit before getting into cir::VectorType::parse.

Btw, I don't want to come across as nit-picking or pushing back, but I see quite a few CIR parser errors that are not tested, e.g.

parser.emitError(loc, "invalid self-reference within record");

Perhaps it would be better to skip testing in this case as well? WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding it. I am fine with both options, but I think it will be better to keep this test case 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, I don't want to come across as nit-picking or pushing back, but I see quite a few CIR parser errors that are not tested, e.g.

I'm sure we've been inconsistent about this. I generally only look for tests for verifier errors and have been satisfied with round-trip tests for printing/parsing. My view is that the tests for the verifier check that we are correctly catching incorrectly constructed IR, which can occur anywhere during IR generation or processing, whereas the printing and parsing are paired so they test each other for correctness.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets me keep the test for now, we can always remove it later.

"failed to parse CIR_VectorType parameter "
"'elementType' which is to be a `mlir::Type`");
return {};
}

// Parse literal '>'
if (odsParser.parseGreater())
return {};
return odsParser.getChecked<VectorType>(odsLoc, odsParser.getContext(),
mlir::Type((*elementType)),
uint64_t((*size)), isScalabe);
}

void cir::VectorType::print(mlir::AsmPrinter &odsPrinter) const {
mlir::Builder odsBuilder(getContext());
odsPrinter << "<";
if (this->getIsScalable())
odsPrinter << "[";

odsPrinter.printStrippedAttrOrType(getSize());
if (this->getIsScalable())
odsPrinter << "]";
odsPrinter << ' ' << "x";
odsPrinter << ' ';
odsPrinter.printStrippedAttrOrType(getElementType());
odsPrinter << ">";
}

//===----------------------------------------------------------------------===//
// TargetAddressSpace definitions
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions clang/test/CIR/IR/invalid-vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>

// -----

cir.func @invalid_scalable_vec_syntax() {
// expected-error@+2 {{missing closing `]` for scalable dim size}}
// expected-error @below {{expected ']'}}
%3 = cir.alloca !cir.vector<[1 x] !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>
cir.return

}

// -----

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

Expand Down
2 changes: 2 additions & 0 deletions clang/test/CIR/IR/vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ cir.func @vec_int_test() {
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
%1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
%2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
%3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
cir.return
}

// CHECK: cir.func{{.*}} @vec_int_test() {
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CHECK: %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"]
// CHECK: %2 = cir.alloca !cir.vector<2 x !s32i>, !cir.ptr<!cir.vector<2 x !s32i>>, ["c"]
// CHECK: %3 = cir.alloca !cir.vector<[1] x !s32i>, !cir.ptr<!cir.vector<[1] x !s32i>>, ["d"]
// CHECK: cir.return
// CHECK: }

Expand Down