Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions crates/postgresql-cst-parser/src/tree_sitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ pub struct Node<'a> {
pub node_or_token: NodeOrToken<'a>,
}

impl<'a> PartialEq for Node<'a> {
fn eq(&self, other: &Self) -> bool {
self.node_or_token == other.node_or_token
}
}

impl<'a> Eq for Node<'a> {}

#[derive(Debug, Clone)]
pub struct TreeCursor<'a> {
pub input: &'a str,
Expand Down Expand Up @@ -98,6 +106,28 @@ impl std::fmt::Display for Range {
}
}

impl Range {
pub fn extended_by(&self, other: &Self) -> Self {
Range {
start_byte: self.start_byte.min(other.start_byte),
end_byte: self.end_byte.max(other.end_byte),

start_position: Point {
row: self.start_position.row.min(other.start_position.row),
column: self.start_position.column.min(other.start_position.column),
},
end_position: Point {
row: self.end_position.row.max(other.end_position.row),
column: self.end_position.column.max(other.end_position.column),
},
}
}

pub fn is_adjacent(&self, other: &Self) -> bool {
self.end_byte == other.start_byte || self.start_byte == other.end_byte
}
}

impl<'a> Node<'a> {
pub fn walk(&self) -> TreeCursor<'a> {
TreeCursor {
Expand Down Expand Up @@ -144,6 +174,48 @@ impl<'a> Node<'a> {
}
}

pub fn children(&self) -> Vec<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.children_with_tokens()
.map(|node| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: node,
})
.collect()
} else {
vec![]
}
}

/// Returns the first child element of this node.
/// this is not tree-sitter's API
pub fn first_child(&self) -> Option<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.first_child_or_token().map(|child| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: child,
})
} else {
None
}
}

/// Returns the last child element of this node.
/// this is not tree-sitter's API
pub fn last_child(&self) -> Option<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.last_child_or_token().map(|child| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: child,
})
} else {
None
}
}

pub fn next_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.next_sibling_or_token()
Expand All @@ -154,6 +226,16 @@ impl<'a> Node<'a> {
})
}

pub fn prev_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.prev_sibling_or_token()
.map(|sibling| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: sibling,
})
}

pub fn parent(&self) -> Option<Node<'a>> {
self.node_or_token.parent().map(|parent| Node {
input: self.input,
Expand All @@ -165,6 +247,82 @@ impl<'a> Node<'a> {
pub fn is_comment(&self) -> bool {
matches!(self.kind(), SyntaxKind::C_COMMENT | SyntaxKind::SQL_COMMENT)
}

/// Returns the rightmost token in the subtree of this node.
/// this is not tree-sitter's API
pub fn last_node(&self) -> Option<Node<'a>> {
match &self.node_or_token {
NodeOrToken::Node(node) => node.last_token().map(|token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
NodeOrToken::Token(token) => Some(Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
}
}

/// Returns the next token in the tree.
/// This is not necessarily a direct sibling of this node/token,
/// but will always be further right in the tree.
/// this is not tree-sitter's API
pub fn next_token(&self) -> Option<Node<'a>> {
match &self.node_or_token {
NodeOrToken::Token(token) => token.next_token().map(|next_token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(next_token),
}),
NodeOrToken::Node(node) => {
// For a node, find its last token and then get the next token
node.last_token()
.and_then(|last_token| last_token.next_token())
.map(|next_token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(next_token),
})
}
}
}

/// Returns an iterator over all descendant nodes (including tokens)
/// this is not tree-sitter's API
pub fn descendants(&self) -> impl Iterator<Item = Node<'a>> + '_ {
struct Descendants<'a> {
iter: Box<dyn Iterator<Item = Node<'a>> + 'a>,
}

impl<'a> Iterator for Descendants<'a> {
type Item = Node<'a>;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

if let Some(node) = self.node_or_token.as_node() {
let input = self.input;
let range_map = Rc::clone(&self.range_map);
Descendants {
iter: Box::new(
node.descendants_with_tokens()
.map(move |node_or_token| Node {
input,
range_map: Rc::clone(&range_map),
node_or_token,
}),
),
}
} else {
Descendants {
iter: Box::new(std::iter::empty()),
}
}
}
}

impl<'a> From<Node<'a>> for TreeCursor<'a> {
Expand Down Expand Up @@ -214,6 +372,15 @@ impl<'a> TreeCursor<'a> {
}
}

pub fn goto_prev_sibling(&mut self) -> bool {
if let Some(sibling) = self.node_or_token.prev_sibling_or_token() {
self.node_or_token = sibling;
true
} else {
false
}
}

pub fn is_comment(&self) -> bool {
matches!(
self.node_or_token.kind(),
Expand Down Expand Up @@ -462,4 +629,53 @@ from

assert_eq!(stmt_count, 2);
}

#[test]
fn test_last_node_returns_rightmost_node() {
let src = "SELECT u.*, (v).id, name;";
let tree = parse(src).unwrap();
let root = tree.root_node();

let target_list = root
.descendants()
.find(|node| node.kind() == SyntaxKind::target_list)
.expect("should find target_list");

// last node of the target_list is returned
let last_node = target_list.last_node().expect("should have last node");
assert_eq!(last_node.text(), "name");

let target_els = target_list
.children()
.into_iter()
.filter(|node| node.kind() == SyntaxKind::target_el)
.collect::<Vec<_>>();

let mut last_nodes = target_els
.iter()
.map(|node| node.last_node().expect("should have last node"));

// last node of each target_el is returned
assert_eq!(last_nodes.next().unwrap().text(), "*");
assert_eq!(last_nodes.next().unwrap().text(), "id");
assert_eq!(last_nodes.next().unwrap().text(), "name");
assert!(last_nodes.next().is_none());
}

#[test]
fn test_next_token() {
let src = "SELECT tbl.name as n from TBL;";
let tree = parse(src).unwrap();
let root = tree.root_node();

let name = root
.descendants()
.find(|node| node.kind() == SyntaxKind::NAME_P)
.expect("should find NAME_P");

// Even if not a direct sibling or not belonging to the same subtree, the next_token can retrieve the next token.
let next_token = name.next_token().expect("should have next token");
assert_eq!(next_token.text(), "as");
assert_eq!(next_token.kind(), SyntaxKind::AS);
}
}
Loading