diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 622b706330..9de850442f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4714,28 +4714,26 @@ def _set_result(self, host, connection, pool, response): self.query, cl, error=response, retry_num=self._query_retries) elif isinstance(response, PreparedQueryNotFound): + query_id = response.info + if self.prepared_statement: - query_id = self.prepared_statement.query_id - assert query_id == response.info, \ - "Got different query ID in server response (%s) than we " \ - "had before (%s)" % (response.info, query_id) - else: - query_id = response.info + # Cache local in-flight context first so lookup by either id can succeed. + self.session.cluster.add_prepared(self.prepared_statement.query_id, self.prepared_statement) try: - prepared_statement = self.session.cluster._prepared_statements[query_id] + self.prepared_statement = self.session.cluster._prepared_statements[query_id] except KeyError: if not self.prepared_statement: - log.error("Tried to execute unknown prepared statement: id=%s", - query_id.encode('hex')) + log.error("Tried to execute unknown prepared statement: id=%s", hexlify(query_id)) self._set_final_exception(response) return - else: - prepared_statement = self.prepared_statement - self.session.cluster._prepared_statements[query_id] = prepared_statement + log.warning( + "UNPREPARED for query id %s while executing statement id %s. " + "Could not resolve returned id in cache, proceeding with in-flight context.", + hexlify(query_id), hexlify(self.prepared_statement.query_id)) current_keyspace = self._connection.keyspace - prepared_keyspace = prepared_statement.keyspace + prepared_keyspace = self.prepared_statement.keyspace if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ and prepared_keyspace and current_keyspace != prepared_keyspace: self._set_final_exception( @@ -4745,11 +4743,13 @@ def _set_result(self, host, connection, pool, response): (current_keyspace, prepared_keyspace))) return - log.debug("Re-preparing unrecognized prepared statement against host %s: %s", - host, prepared_statement.query_string) - prepared_keyspace = prepared_statement.keyspace \ + log.debug( + "Re-preparing unrecognized prepared statement against host %s: %s", + host, self.prepared_statement.query_string + ) + prepared_keyspace = self.prepared_statement.keyspace \ if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None - prepare_message = PrepareMessage(query=prepared_statement.query_string, + prepare_message = PrepareMessage(query=self.prepared_statement.query_string, keyspace=prepared_keyspace) # since this might block, run on the executor to avoid hanging # the event loop thread diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 7168ad2940..679a0ad4fe 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -621,6 +621,66 @@ def test_prepared_query_not_found_bad_keyspace(self): with pytest.raises(ValueError): rf.result() + def test_prepared_query_not_found_uses_local_prepared_context(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + session.cluster.protocol_version = ProtocolVersion.V4 + session.cluster._prepared_statements = {} + rf._connection.keyspace = "FooKeyspace" + + rf.prepared_statement = Mock() + rf.prepared_statement.query_id = b"known-query-id" + rf.prepared_statement.query_string = "SELECT * FROM foobar" + rf.prepared_statement.keyspace = "FooKeyspace" + + # Different query id in UNPREPARED response should not prevent reprepare when local context exists. + result = Mock(spec=PreparedQueryNotFound, info=b"other-query-id") + rf._set_result(None, None, None, result) + + assert session.submit.call_args + args, _ = session.submit.call_args + assert rf._reprepare == args[-5] + assert isinstance(args[-4], PrepareMessage) + assert args[-4].query == "SELECT * FROM foobar" + + def test_prepared_query_not_found_prefers_returned_id_from_cache(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + session.cluster.protocol_version = ProtocolVersion.V4 + rf._connection.keyspace = "FooKeyspace" + + rf.prepared_statement = Mock() + rf.prepared_statement.query_id = b"local-id" + rf.prepared_statement.query_string = "SELECT * FROM local_ctx" + rf.prepared_statement.keyspace = "FooKeyspace" + + cached_stmt = Mock() + cached_stmt.query_id = b"returned-id" + cached_stmt.query_string = "SELECT * FROM returned_ctx" + cached_stmt.keyspace = "FooKeyspace" + session.cluster._prepared_statements = {cached_stmt.query_id: cached_stmt} + + result = Mock(spec=PreparedQueryNotFound, info=cached_stmt.query_id) + rf._set_result(None, None, None, result) + + assert session.submit.call_args + args, _ = session.submit.call_args + assert rf._reprepare == args[-5] + assert isinstance(args[-4], PrepareMessage) + assert args[-4].query == "SELECT * FROM returned_ctx" + def test_repeat_orig_query_after_succesful_reprepare(self): query_id = b'abc123' # Just a random binary string so we don't hit id mismatch exception session = self.make_session()