Skip to content

Commit 79ba9de

Browse files
committed
support insert batch sql
1 parent ddb48d2 commit 79ba9de

File tree

6 files changed

+240
-15
lines changed

6 files changed

+240
-15
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<modelVersion>4.0.0</modelVersion>
55
<groupId>com.codingapi.dbstream</groupId>
66
<artifactId>dbstream-driver</artifactId>
7-
<version>1.0.15</version>
7+
<version>1.0.16</version>
88

99
<url>https://github.com/codingapi/dbstream-driver</url>
1010
<name>dbstream-driver</name>

src/main/java/com/codingapi/dbstream/listener/SQLRunningState.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class SQLRunningState {
3333
* 批量模式标识
3434
*/
3535
@Getter
36-
private boolean batchMode = false;
36+
private boolean jdbcBatchMode = false;
3737

3838
/**
3939
* 当前绑定sql
@@ -112,7 +112,7 @@ public void setSql(String sql) {
112112
* @param sql 执行sql
113113
*/
114114
public void addBatch(String sql) {
115-
batchMode = true;
115+
jdbcBatchMode = true;
116116
SQLRunningParam executeParam = new SQLRunningParam();
117117
executeParam.setSql(sql);
118118
this.sqlRunningParams.add(executeParam);
@@ -198,7 +198,7 @@ public void setParam(int index, Object value) {
198198
* @return List
199199
*/
200200
public List<Object> getListParams() {
201-
if (batchMode) {
201+
if (jdbcBatchMode) {
202202
if (this.sqlRunningParams.isEmpty()) {
203203
return new ArrayList<>();
204204
}
@@ -218,7 +218,7 @@ public List<Object> getListParams() {
218218
* @return List
219219
*/
220220
public List<SQLRunningState> getBatchSQLRunningStateList() {
221-
if (this.batchMode) {
221+
if (this.jdbcBatchMode) {
222222
if (this.sqlRunningParams.isEmpty()) {
223223
return new ArrayList<>();
224224
}
@@ -304,8 +304,7 @@ public List<Map<String, Object>> getStatementGenerateKeys(DbTable dbTable) {
304304
list.add(map);
305305
}
306306
}
307-
} catch (SQLException ignored) {
308-
}
307+
} catch (SQLException ignored) {}
309308
return list;
310309
}
311310

src/main/java/com/codingapi/dbstream/listener/dbevent/DBEventListener.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void before(SQLRunningState runningState) throws SQLException {
5151
// 判断是否支持对该表的DB事件支持
5252
if (dbTable != null && DBStreamContext.getInstance().support(runningState.getDriverProperties(), dbTable)) {
5353
// 是否批量模式判断
54-
if (runningState.isBatchMode()) {
54+
if (runningState.isJdbcBatchMode()) {
5555
// 批量模式下,将获取批量的SQL执行结果数据
5656
List<SQLRunningState> runningStateList = runningState.getBatchSQLRunningStateList();
5757
for (int i = 0; i < runningStateList.size(); i++) {
@@ -114,8 +114,8 @@ public void after(SQLRunningState runningState, Object result) throws SQLExcepti
114114
// 获取事务标识信息
115115
String transactionKey = runningState.getTransactionKey();
116116
if (this.support(sql)) {
117-
// 批量模式
118-
if (runningState.isBatchMode()) {
117+
// Jdbc批量模式
118+
if (runningState.isJdbcBatchMode()) {
119119
List<SQLRunningState> runningStateList = runningState.getBatchSQLRunningStateList();
120120
int batchSize = runningStateList.size();
121121

@@ -134,7 +134,7 @@ public void after(SQLRunningState runningState, Object result) throws SQLExcepti
134134
// 清空本地缓存数据
135135
DBEventCacheContext.getInstance().remove();
136136
} else {
137-
// 非批量模式
137+
// 非Jdbc批量模式
138138
DBEventParser dataParser = DBEventCacheContext.getInstance().get();
139139
if (dataParser != null) {
140140
// 获取DB事件信息

src/main/java/com/codingapi/dbstream/parser/InsertDBEventParser.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@ private void loadSelectInsertDataList() throws SQLException {
109109
this.dataList = this.executeState.query(query, queryParams);
110110
}
111111

112-
private void loadDefaultInsertDataList() throws SQLException {
113-
List<InsertSQLParser.InsertValue> values = this.sqlParser.getValues();
114-
List<Object> paramList = this.executeState.getListParams();
112+
113+
private Map<String, Object> loadDefaultInsertData(List<Object> paramList,List<InsertSQLParser.InsertValue> values) throws SQLException{
115114
Map<String, Object> data = new HashMap<>();
116115
for (int i = 0; i < columns.size(); i++) {
117116
String column = columns.get(i);
@@ -134,7 +133,23 @@ private void loadDefaultInsertDataList() throws SQLException {
134133
}
135134
data.put(column, value);
136135
}
137-
dataList.add(data);
136+
return data;
137+
}
138+
139+
private void loadDefaultInsertDataList() throws SQLException {
140+
if(this.sqlParser.isBatchInsertSQL()){
141+
List<List<InsertSQLParser.InsertValue>> valueList = this.sqlParser.getBatchValues();
142+
List<Object> paramList = this.executeState.getListParams();
143+
for (List<InsertSQLParser.InsertValue> values: valueList){
144+
Map<String, Object> data = this.loadDefaultInsertData(paramList,values);
145+
dataList.add(data);
146+
}
147+
}else {
148+
List<InsertSQLParser.InsertValue> values = this.sqlParser.getValues();
149+
List<Object> paramList = this.executeState.getListParams();
150+
Map<String, Object> data = this.loadDefaultInsertData(paramList,values);
151+
dataList.add(data);
152+
}
138153
}
139154

140155
}

src/main/java/com/codingapi/dbstream/parser/InsertSQLParser.java

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ public boolean isDefaultInsertSQL() {
7373
return valuesMatcher.find();
7474
}
7575

76+
/**
77+
* 是否为批量的insert语句类型
78+
* INSERT INTO user (id,name) VALUES (?,?),(?,?)
79+
*/
80+
public boolean isBatchInsertSQL() {
81+
if (!isDefaultInsertSQL()) {
82+
return false;
83+
}
84+
85+
String valuesSQL = getValuesSQL();
86+
if (valuesSQL == null) {
87+
return false;
88+
}
89+
90+
String normalized = valuesSQL.replaceAll("\\s+", "");
91+
return normalized.contains("),(");
92+
}
93+
94+
7695
/**
7796
* 提取 VALUES 或 SELECT 后面的 SQL 内容(包含完整结构)
7897
* 示例:
@@ -103,6 +122,100 @@ public String getValuesSQL() {
103122
return null;
104123
}
105124

125+
private List<String> splitValueGroups(String input) {
126+
List<String> groups = new ArrayList<>();
127+
128+
int level = 0;
129+
StringBuilder current = new StringBuilder();
130+
131+
for (int i = 0; i < input.length(); i++) {
132+
char c = input.charAt(i);
133+
134+
if (c == '(') {
135+
if (level > 0) {
136+
current.append(c);
137+
}
138+
level++;
139+
} else if (c == ')') {
140+
level--;
141+
if (level > 0) {
142+
current.append(c);
143+
} else {
144+
// 一个完整 group
145+
groups.add("(" + current.toString() + ")");
146+
current.setLength(0);
147+
}
148+
} else if (c == ',' && level == 0) {
149+
// group 之间的逗号,忽略
150+
continue;
151+
} else {
152+
if (level > 0) {
153+
current.append(c);
154+
}
155+
}
156+
}
157+
158+
return groups;
159+
}
160+
161+
public List<List<InsertValue>> getBatchValues() {
162+
List<List<InsertValue>> result = new ArrayList<>();
163+
164+
String valuesSQL = getValuesSQL();
165+
if (valuesSQL == null) {
166+
return result;
167+
}
168+
169+
// 去掉 VALUES 关键字
170+
String normalized = valuesSQL.trim();
171+
// 如果没有以 ( 开头,补一个
172+
if (!normalized.startsWith("(")) {
173+
normalized = "(" + normalized;
174+
}
175+
176+
// 如果没有以 ) 结尾,补一个
177+
if (!normalized.endsWith(")")) {
178+
normalized = normalized + ")";
179+
}
180+
181+
List<String> groups = splitValueGroups(normalized);
182+
183+
int jdbcIndex = 0;
184+
185+
for (String group : groups) {
186+
// 去掉外层括号
187+
String inner = group.trim();
188+
if (inner.startsWith("(") && inner.endsWith(")")) {
189+
inner = inner.substring(1, inner.length() - 1);
190+
}
191+
192+
List<String> values = SQLUtils.parseInsertSQLValues(inner);
193+
List<InsertValue> row = new ArrayList<>();
194+
195+
for (String value : values) {
196+
InsertValue insertValue = new InsertValue();
197+
198+
String v = value.trim();
199+
if ("?".equals(v)) {
200+
insertValue.setType(ValueType.JDBC);
201+
jdbcIndex++;
202+
insertValue.setValue("?" + jdbcIndex);
203+
} else if (SQLUtils.isSQLKeyword(v) || v.startsWith("(")) {
204+
insertValue.setType(ValueType.SELECT);
205+
insertValue.setValue(v);
206+
} else {
207+
insertValue.setType(ValueType.STATIC);
208+
insertValue.setValue(v);
209+
}
210+
211+
row.add(insertValue);
212+
}
213+
214+
result.add(row);
215+
}
216+
217+
return result;
218+
}
106219

107220
public List<InsertValue> getValues() {
108221
List<InsertValue> insertValues = new ArrayList<>();
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package com.example.dbstream.tests;
2+
3+
4+
import com.codingapi.dbstream.DBStreamContext;
5+
import com.codingapi.dbstream.event.DBEvent;
6+
import com.codingapi.dbstream.event.DBEventPusher;
7+
import com.codingapi.dbstream.query.JdbcQuery;
8+
import com.example.dbstream.entity.User3;
9+
import com.example.dbstream.listener.MySQLListener;
10+
import com.example.dbstream.repository.User3Repository;
11+
import org.junit.jupiter.api.Order;
12+
import org.junit.jupiter.api.Test;
13+
import org.springframework.beans.factory.annotation.Autowired;
14+
import org.springframework.boot.test.context.SpringBootTest;
15+
import org.springframework.jdbc.core.JdbcTemplate;
16+
import org.springframework.test.annotation.Rollback;
17+
18+
import javax.transaction.Transactional;
19+
import java.util.List;
20+
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
23+
24+
@SpringBootTest
25+
class InsertBatchValuesTest {
26+
27+
@Autowired
28+
private JdbcTemplate jdbcTemplate;
29+
30+
@Autowired
31+
private User3Repository user3Repository;
32+
33+
/**
34+
* 常用操作测试
35+
*/
36+
@Test
37+
@Transactional
38+
@Rollback(false)
39+
@Order(1)
40+
void test1() {
41+
42+
user3Repository.deleteAll();
43+
44+
DBStreamContext.getInstance().addListener(new MySQLListener());
45+
46+
DBStreamContext.getInstance().cleanEventPushers();
47+
DBStreamContext.getInstance().addEventPusher(new DBEventPusher() {
48+
@Override
49+
public void push(JdbcQuery jdbcQuery, List<DBEvent> events) {
50+
System.out.println(events);
51+
}
52+
});
53+
54+
String insertSQL = "insert into m_user_3 (username,password,email,nickname) values ('1','1','1','1')";
55+
56+
jdbcTemplate.update(insertSQL);
57+
58+
List<User3> userList = user3Repository.findAll();
59+
60+
assertEquals(1,userList.size());
61+
62+
}
63+
64+
65+
/**
66+
* 常用操作测试
67+
*/
68+
@Test
69+
@Transactional
70+
@Rollback(false)
71+
@Order(2)
72+
void test2() {
73+
74+
user3Repository.deleteAll();
75+
76+
DBStreamContext.getInstance().addListener(new MySQLListener());
77+
78+
DBStreamContext.getInstance().cleanEventPushers();
79+
DBStreamContext.getInstance().addEventPusher(new DBEventPusher() {
80+
@Override
81+
public void push(JdbcQuery jdbcQuery, List<DBEvent> events) {
82+
System.out.println(events);
83+
}
84+
});
85+
86+
String insertSQL = "insert into m_user_3 (username,password,email,nickname) values ('1','1','1','1'),('1','1','1','1')";
87+
88+
jdbcTemplate.update(insertSQL);
89+
90+
List<User3> userList = user3Repository.findAll();
91+
92+
assertEquals(2,userList.size());
93+
94+
}
95+
96+
97+
98+
}

0 commit comments

Comments
 (0)