Skip to content

Commit f914828

Browse files
authored
Merge pull request #21816 from github/tausbn/yeast-mutate-in-place
yeast: Two minor performance optimisations
2 parents 36554d1 + 15936a5 commit f914828

3 files changed

Lines changed: 62 additions & 71 deletions

File tree

shared/yeast/src/lib.rs

Lines changed: 51 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -34,44 +34,48 @@ pub const CHILD_FIELD: u16 = u16::MAX;
3434
#[derive(Debug)]
3535
pub struct AstCursor<'a> {
3636
ast: &'a Ast,
37-
/// A stack of parents, along with iterators for their children
38-
parents: Vec<(&'a Node, ChildrenIter<'a>)>,
39-
node: &'a Node,
37+
/// A stack of parents, along with iterators for their children.
38+
parents: Vec<(Id, ChildrenIter<'a>)>,
39+
node_id: Id,
4040
}
4141

4242
impl<'a> AstCursor<'a> {
4343
pub fn new(ast: &'a Ast) -> Self {
44-
// TODO: handle non-zero root
45-
let node = ast.get_node(ast.root).unwrap();
4644
Self {
4745
ast,
4846
parents: vec![],
49-
node,
47+
node_id: ast.root,
5048
}
5149
}
5250

51+
/// The Id of the node currently under the cursor.
52+
pub fn node_id(&self) -> Id {
53+
self.node_id
54+
}
55+
5356
fn goto_next_sibling_opt(&mut self) -> Option<()> {
54-
self.node = self.parents.last_mut()?.1.next()?;
57+
self.node_id = self.parents.last_mut()?.1.next()?;
5558
Some(())
5659
}
5760

5861
fn goto_first_child_opt(&mut self) -> Option<()> {
59-
let parent = self.node;
60-
let mut children = ChildrenIter::new(self.ast, parent);
62+
let parent_id = self.node_id;
63+
let parent = self.ast.get_node(parent_id)?;
64+
let mut children = ChildrenIter::new(parent);
6165
let first_child = children.next()?;
62-
self.node = first_child;
63-
self.parents.push((parent, children));
66+
self.node_id = first_child;
67+
self.parents.push((parent_id, children));
6468
Some(())
6569
}
6670

6771
fn goto_parent_opt(&mut self) -> Option<()> {
68-
self.node = self.parents.pop()?.0;
72+
self.node_id = self.parents.pop()?.0;
6973
Some(())
7074
}
7175
}
7276
impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
7377
fn node(&self) -> &'a Node {
74-
self.node
78+
&self.ast.nodes[self.node_id]
7579
}
7680

7781
fn field_id(&self) -> Option<FieldId> {
@@ -101,36 +105,30 @@ impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
101105
}
102106
}
103107

104-
/// An iterator over all the child nodes of a node.
108+
/// An iterator over the child Ids of a node.
105109
#[derive(Debug)]
106110
struct ChildrenIter<'a> {
107-
ast: &'a Ast,
108111
current_field: Option<FieldId>,
109112
fields: std::collections::btree_map::Iter<'a, FieldId, Vec<Id>>,
110113
field_children: Option<std::slice::Iter<'a, Id>>,
111114
}
112115

113116
impl<'a> ChildrenIter<'a> {
114-
fn new(ast: &'a Ast, node: &'a Node) -> Self {
117+
fn new(node: &'a Node) -> Self {
115118
Self {
116-
ast,
117119
current_field: None,
118120
fields: node.fields.iter(),
119121
field_children: None,
120122
}
121123
}
122124

123-
fn get_node(&self, id: Id) -> &'a Node {
124-
self.ast.get_node(id).unwrap()
125-
}
126-
127125
fn current_field(&self) -> Option<FieldId> {
128126
self.current_field
129127
}
130128
}
131129

132-
impl<'a> Iterator for ChildrenIter<'a> {
133-
type Item = &'a Node;
130+
impl Iterator for ChildrenIter<'_> {
131+
type Item = Id;
134132

135133
fn next(&mut self) -> Option<Self::Item> {
136134
match self.field_children.as_mut() {
@@ -151,7 +149,7 @@ impl<'a> Iterator for ChildrenIter<'a> {
151149
self.next()
152150
}
153151
},
154-
Some(child_id) => Some(self.get_node(*child_id)),
152+
Some(child_id) => Some(*child_id),
155153
},
156154
}
157155
}
@@ -236,7 +234,6 @@ impl Ast {
236234
) -> Id {
237235
let id = self.nodes.len();
238236
self.nodes.push(Node {
239-
id,
240237
kind,
241238
kind_name: self.schema.node_kind_for_id(kind).unwrap(),
242239
fields,
@@ -265,7 +262,6 @@ impl Ast {
265262
});
266263
let id = self.nodes.len();
267264
self.nodes.push(Node {
268-
id,
269265
kind: kind_id,
270266
kind_name: kind,
271267
is_named: true,
@@ -345,7 +341,6 @@ impl Ast {
345341
/// A node in our AST
346342
#[derive(PartialEq, Eq, Debug, Clone, Serialize)]
347343
pub struct Node {
348-
id: Id,
349344
kind: KindId,
350345
kind_name: &'static str,
351346
pub(crate) fields: BTreeMap<FieldId, Vec<Id>>,
@@ -361,10 +356,6 @@ pub struct Node {
361356
}
362357

363358
impl Node {
364-
pub fn id(&self) -> Id {
365-
self.id
366-
}
367-
368359
pub fn kind(&self) -> &'static str {
369360
self.kind_name
370361
}
@@ -600,39 +591,41 @@ fn apply_rules_inner(
600591
}
601592
}
602593

603-
// Collect fields before recursing (avoids borrowing ast immutably during mutation)
604-
let field_entries: Vec<(FieldId, Vec<Id>)> = ast.nodes[id]
605-
.fields
606-
.iter()
607-
.map(|(&fid, children)| (fid, children.clone()))
608-
.collect();
609-
610-
// recursively descend into all the fields
594+
// Take the parent's fields by ownership: the recursion will rewrite
595+
// each child Id, and we'll write the (possibly mutated) field map back
596+
// when we're done. Avoids cloning the whole BTreeMap and its child
597+
// Vecs on entry. Each child Vec is only re-allocated if a rewrite
598+
// actually changes its contents.
599+
//
611600
// Child traversal does not increment rewrite depth and starts fresh
612601
// (no rule is skipped on child subtrees).
613-
let mut changed = false;
614-
let mut new_fields = BTreeMap::new();
615-
for (field_id, children) in field_entries {
616-
let mut new_children = Vec::new();
617-
for child_id in children {
602+
let mut fields = std::mem::take(&mut ast.nodes[id].fields);
603+
for children in fields.values_mut() {
604+
let mut new_children: Option<Vec<Id>> = None;
605+
for (i, &child_id) in children.iter().enumerate() {
618606
let result = apply_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?;
619-
if result.len() != 1 || result[0] != child_id {
620-
changed = true;
607+
let unchanged = result.len() == 1 && result[0] == child_id;
608+
match (&mut new_children, unchanged) {
609+
(None, true) => {} // unchanged so far, no allocation needed
610+
(None, false) => {
611+
// First divergence — copy already-processed Ids and
612+
// start collecting the rewritten sequence.
613+
let mut new = Vec::with_capacity(children.len());
614+
new.extend_from_slice(&children[..i]);
615+
new.extend(result);
616+
new_children = Some(new);
617+
}
618+
(Some(new), _) => {
619+
new.extend(result);
620+
}
621621
}
622-
new_children.extend(result);
623622
}
624-
new_fields.insert(field_id, new_children);
625-
}
626-
627-
if !changed {
628-
return Ok(vec![id]);
623+
if let Some(new) = new_children {
624+
*children = new;
625+
}
629626
}
630-
631-
let mut node = ast.nodes[id].clone();
632-
node.fields = new_fields;
633-
node.id = ast.nodes.len();
634-
ast.nodes.push(node);
635-
Ok(vec![ast.nodes.len() - 1])
627+
ast.nodes[id].fields = fields;
628+
Ok(vec![id])
636629
}
637630

638631
/// One phase of a desugaring pass: a named bundle of rules that runs to

shared/yeast/src/visitor.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ impl Visitor {
4949

5050
pub fn build_with_schema(self, schema: crate::schema::Schema) -> Ast {
5151
Ast {
52-
root: self.nodes[0].inner.id,
52+
root: 0,
5353
schema,
5454
nodes: self.nodes.into_iter().map(|n| n.inner).collect(),
5555
}
@@ -59,7 +59,6 @@ impl Visitor {
5959
let id = self.nodes.len();
6060
self.nodes.push(VisitorNode {
6161
inner: Node {
62-
id,
6362
kind: self.language.id_for_node_kind(n.kind(), is_named),
6463
kind_name: n.kind(),
6564
content,
@@ -82,11 +81,10 @@ impl Visitor {
8281
}
8382

8483
fn leave_node(&mut self, field_name: Option<&'static str>, _node: tree_sitter::Node<'_>) {
85-
let node = self.current.map(|i| &self.nodes[i]).unwrap();
86-
let node_id = node.inner.id;
87-
let node_parent = node.parent;
84+
let node_id = self.current.unwrap();
85+
let node_parent = self.nodes[node_id].parent;
8886

89-
if let Some(parent_id) = node.parent {
87+
if let Some(parent_id) = node_parent {
9088
let parent = self.nodes.get_mut(parent_id).unwrap();
9189
if let Some(field) = field_name {
9290
let field_id = self.language.field_id_for_name(field).unwrap().get();

shared/yeast/tests/test.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ fn test_query_repeated_capture() {
182182
// Match against the assignment node (first named child of program)
183183
let mut cursor = AstCursor::new(&ast);
184184
cursor.goto_first_child();
185-
let assignment_id = cursor.node().id();
185+
let assignment_id = cursor.node_id();
186186

187187
let mut captures = yeast::captures::Captures::new();
188188
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
@@ -206,7 +206,7 @@ fn test_capture_unnamed_node_parenthesized() {
206206

207207
let mut cursor = AstCursor::new(&ast);
208208
cursor.goto_first_child();
209-
let assignment_id = cursor.node().id();
209+
let assignment_id = cursor.node_id();
210210

211211
let mut captures = yeast::captures::Captures::new();
212212
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
@@ -233,7 +233,7 @@ fn test_capture_unnamed_node_bare_literal() {
233233

234234
let mut cursor = AstCursor::new(&ast);
235235
cursor.goto_first_child();
236-
let assignment_id = cursor.node().id();
236+
let assignment_id = cursor.node_id();
237237

238238
let mut captures = yeast::captures::Captures::new();
239239
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
@@ -254,7 +254,7 @@ fn test_bare_underscore_matches_unnamed() {
254254

255255
let mut cursor = AstCursor::new(&ast);
256256
cursor.goto_first_child();
257-
let assignment_id = cursor.node().id();
257+
let assignment_id = cursor.node_id();
258258

259259
// `(_)` skips unnamed children, so a query containing a single `(_)`
260260
// bare pattern fails to match the assignment (whose only unfielded
@@ -293,7 +293,7 @@ fn test_bare_forms_in_field_position() {
293293

294294
let mut cursor = AstCursor::new(&ast);
295295
cursor.goto_first_child();
296-
let assignment_id = cursor.node().id();
296+
let assignment_id = cursor.node_id();
297297

298298
// Bare `_` in field position. Captures the named `identifier "x"`
299299
// child of the `left` field — bare `_` admits unnamed too, but the
@@ -337,7 +337,7 @@ fn test_forward_scan_finds_unnamed_token_late() {
337337
while cursor.node().kind() != "do" || !cursor.node().is_named() {
338338
assert!(cursor.goto_next_sibling(), "expected to find named `do`");
339339
}
340-
let do_id = cursor.node().id();
340+
let do_id = cursor.node_id();
341341

342342
let query = yeast::query!((do ("end") @kw));
343343
let mut captures = yeast::captures::Captures::new();
@@ -363,7 +363,7 @@ fn test_forward_scan_preserves_order() {
363363
while cursor.node().kind() != "do" || !cursor.node().is_named() {
364364
assert!(cursor.goto_next_sibling(), "expected to find named `do`");
365365
}
366-
let do_id = cursor.node().id();
366+
let do_id = cursor.node_id();
367367

368368
let query = yeast::query!((do ("end") @first ("do") @second));
369369
let mut captures = yeast::captures::Captures::new();

0 commit comments

Comments
 (0)