Skip to content
5 changes: 4 additions & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5006,7 +5006,10 @@ def _query(self, host, message=None, cb=None):
try:
# TODO get connectTimeout from cluster settings
if self.query:
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table)
connection, request_id = pool.borrow_connection(
timeout=2.0, routing_key=self.query.routing_key,
keyspace=self.query.keyspace, table=self.query.table,
tablet=getattr(self.query, '_tablet', None))
else:
connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection
Expand Down
8 changes: 6 additions & 2 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,14 @@ def make_query_plan(self, working_keyspace=None, query=None):
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))

if tablet is not None:
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
replica_dict = tablet._replica_dict
child_plan = child.make_query_plan(keyspace, query)

replicas = [host for host in child_plan if host.host_id in replicas_mapped]
replicas = [host for host in child_plan if host.host_id in replica_dict]
# Stash the tablet so that downstream shard-aware
# connection selection can reuse it instead of
# repeating the bisect lookup.
query._tablet = tablet
else:
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)

Expand Down
26 changes: 14 additions & 12 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def __init__(self, host, host_distance, session):

log.debug("Finished initializing connection for host %s", self.host)

def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None):
def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None, tablet=None):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
Expand All @@ -454,16 +454,18 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table

shard_id = None
if self.tablets_routing_v1 and table is not None:
if keyspace is None:
keyspace = self._keyspace
# Reuse tablet from query planning if available, avoiding
# a redundant bisect lookup in the tablet map.
if tablet is not None:
shard_id = tablet._replica_dict.get(self.host.host_id)
else:
if keyspace is None:
keyspace = self._keyspace

tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)
tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)

if tablet is not None:
for replica in tablet.replicas:
if replica[0] == self.host.host_id:
shard_id = replica[1]
break
if tablet is not None:
shard_id = tablet._replica_dict.get(self.host.host_id)

if shard_id is None:
shard_id = self.host.sharding_info.shard_id_from_token(t.value)
Expand Down Expand Up @@ -506,15 +508,15 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
return random.choice(active_connections)
return random.choice(list(self._connections.values()))

def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None):
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None, tablet=None):
conn = self._get_connection_for_routing_key(routing_key, keyspace, table, tablet)
start = time.time()
remaining = timeout
last_retry = False
while True:
if conn.is_closed:
# The connection might have been closed in the meantime - if so, try again
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
conn = self._get_connection_for_routing_key(routing_key, keyspace, table, tablet)
with conn.lock:
if (not conn.is_closed or last_retry) and conn.in_flight < conn.max_request_id:
# On last retry we ignore connection status, since it is better to return closed connection than
Expand Down
99 changes: 59 additions & 40 deletions cassandra/tablets.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,84 @@
from bisect import bisect_left
from operator import attrgetter
from threading import Lock
from typing import Optional
from uuid import UUID

# C-accelerated attrgetter avoids per-call lambda allocation overhead
_get_first_token = attrgetter("first_token")
_get_last_token = attrgetter("last_token")


class Tablet(object):
    """
    Represents a single ScyllaDB tablet.
    It stores information about each replica, its host and shard,
    and the token interval in the format (first_token, last_token].

    ``replicas`` is stored as a tuple of replica entries (or ``None`` when
    no replicas were given). ``_replica_dict`` maps ``replica[0]`` (the host
    id) to ``replica[1]`` (the shard) so that membership and shard lookups
    do not need to scan ``replicas``.
    """
    # No class-level defaults for these names: a class variable with the same
    # name as a slot raises ValueError when the class is created.
    __slots__ = ('first_token', 'last_token', 'replicas', '_replica_dict')

    def __init__(self, first_token=0, last_token=0, replicas=None):
        """
        :param first_token: exclusive lower bound of the token interval.
        :param last_token: inclusive upper bound of the token interval.
        :param replicas: iterable of indexable entries where ``entry[0]`` is
            the host id and ``entry[1]`` is the shard, or ``None``.
        """
        self.first_token = first_token
        self.last_token = last_token
        if replicas is not None:
            # Materialize once so the tuple and the dict are built from the
            # same snapshot even if ``replicas`` is a one-shot iterable.
            replicas_tuple = tuple(replicas)
            self.replicas = replicas_tuple
            self._replica_dict = {r[0]: r[1] for r in replicas_tuple}
        else:
            self.replicas = None
            self._replica_dict = {}

    def __str__(self):
        return "<Tablet: first_token=%s last_token=%s replicas=%s>" \
               % (self.first_token, self.last_token, self.replicas)
    __repr__ = __str__

    @staticmethod
    def _is_valid_tablet(replicas):
        """Return True when ``replicas`` is neither ``None`` nor empty."""
        return replicas is not None and len(replicas) != 0

    @staticmethod
    def from_row(first_token, last_token, replicas):
        """
        Build a Tablet from row values.

        Returns ``None`` when ``replicas`` is ``None`` or empty, since a
        tablet without replicas cannot be used for routing.
        """
        if not replicas:
            return None
        return Tablet(first_token, last_token, replicas)

    def replica_contains_host_id(self, uuid: UUID) -> bool:
        """Return True when one of this tablet's replicas has host id ``uuid``."""
        return uuid in self._replica_dict


class Tablets(object):
    """
    Per-table registry of tablets, kept sorted by token range.

    For every ``(keyspace, table)`` key three parallel lists are maintained:
    the tablets themselves plus their ``first_token`` and ``last_token``
    values, so that ``bisect_left`` can run over plain lists without a
    ``key=`` callable. All mutating methods take ``self._lock``.

    NOTE(review): ``get_tablet_for_key`` reads the lists without taking the
    lock while ``add_tablet`` mutates them in place; confirm that callers
    tolerate a lookup racing with an update.
    """
    _lock = None
    _tablets = {}  # (keyspace, table) -> list[Tablet]
    _first_tokens = {}  # (keyspace, table) -> list[int]
    _last_tokens = {}  # (keyspace, table) -> list[int]

    def __init__(self, tablets):
        """
        :param tablets: dict mapping ``(keyspace, table)`` to a list of
            tablets; the dict is stored by reference, not copied.
        """
        self._tablets = tablets
        # Build parallel token index lists from any pre-populated data
        self._first_tokens = {
            key: [t.first_token for t in tlist]
            for key, tlist in tablets.items()
        }
        self._last_tokens = {
            key: [t.last_token for t in tlist]
            for key, tlist in tablets.items()
        }
        self._lock = Lock()

    def table_has_tablets(self, keyspace, table) -> bool:
        """Return True when at least one tablet is known for the table."""
        return bool(self._tablets.get((keyspace, table), []))

    def get_tablet_for_key(self, keyspace, table, t):
        """
        Return the tablet whose (first_token, last_token] interval contains
        ``t.value``, or ``None`` when the table has no such tablet.
        """
        key = (keyspace, table)
        last_tokens = self._last_tokens.get(key)
        if not last_tokens:
            return None

        token_value = t.value
        # First tablet whose last_token >= token_value; it is a match only if
        # the token is also strictly above that tablet's first_token.
        index = bisect_left(last_tokens, token_value)
        if index < len(last_tokens) and token_value > self._first_tokens[key][index]:
            return self._tablets[key][index]
        return None

    def drop_tablets(self, keyspace: str, table: Optional[str] = None):
        """
        Forget the tablets of one table, or of every table in ``keyspace``
        when ``table`` is ``None``.
        """
        with self._lock:
            if table is not None:
                key = (keyspace, table)
                self._tablets.pop(key, None)
                self._first_tokens.pop(key, None)
                self._last_tokens.pop(key, None)
                return

            # NOTE(review): the key selection was elided in the reviewed diff
            # and is reconstructed here as "every key whose keyspace part
            # matches" -- confirm against the full file.
            to_be_deleted = [key for key in self._tablets if key[0] == keyspace]

            for key in to_be_deleted:
                del self._tablets[key]
                self._first_tokens.pop(key, None)
                self._last_tokens.pop(key, None)

    def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
        """
        Remove every tablet that has a replica on ``host_id``.

        Does nothing when ``host_id`` is ``None``.
        """
        if host_id is None:
            return
        with self._lock:
            # Re-assigning an existing key does not resize the dict, so it is
            # safe while iterating over items().
            for key, tablets in self._tablets.items():
                # Filter in one pass instead of popping one-by-one (O(n) vs O(k*n))
                kept = [t for t in tablets
                        if not t.replica_contains_host_id(host_id)]
                if len(kept) == len(tablets):
                    continue  # nothing to drop
                self._tablets[key] = kept
                # Derive the token indexes from the kept tablets rather than
                # from the old index lists, so the three lists cannot drift.
                self._first_tokens[key] = [t.first_token for t in kept]
                self._last_tokens[key] = [t.last_token for t in kept]

    def add_tablet(self, keyspace, table, tablet):
        """
        Insert ``tablet`` in sorted position, first removing every stored
        tablet whose token range overlaps it.
        """
        with self._lock:
            key = (keyspace, table)
            tablets_for_table = self._tablets.setdefault(key, [])
            first_tokens = self._first_tokens.setdefault(key, [])
            last_tokens = self._last_tokens.setdefault(key, [])

            # find first overlapping range
            start = bisect_left(first_tokens, tablet.first_token)
            if start > 0 and last_tokens[start - 1] > tablet.first_token:
                start = start - 1

            # find last overlapping range
            end = bisect_left(last_tokens, tablet.last_token)
            if end < len(last_tokens) and first_tokens[end] >= tablet.last_token:
                end = end - 1

            if start <= end:
                del tablets_for_table[start:end + 1]
                del first_tokens[start:end + 1]
                del last_tokens[start:end + 1]

            tablets_for_table.insert(start, tablet)
            first_tokens.insert(start, tablet.first_token)
            last_tokens.insert(start, tablet.last_token)

12 changes: 6 additions & 6 deletions tests/unit/test_response_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_result_message(self):
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)

connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

Expand Down Expand Up @@ -284,7 +284,7 @@ def test_retry_policy_says_retry(self):
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

result = Mock(spec=UnavailableErrorMessage, info={})
Expand All @@ -303,7 +303,7 @@ def test_retry_policy_says_retry(self):
# it should try again with the same host since this was
# an UnavailableException
rf.session._pools.get.assert_called_with(host)
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

def test_retry_with_different_host(self):
Expand All @@ -318,7 +318,7 @@ def test_retry_with_different_host(self):
rf.send_request()

rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
assert ConsistencyLevel.QUORUM == rf.message.consistency_level

Expand All @@ -337,7 +337,7 @@ def test_retry_with_different_host(self):

# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

# the consistency level should be the same
Expand Down Expand Up @@ -982,7 +982,7 @@ def test_single_host_query_plan_exhausted_after_one_retry(self):

# Verify initial request was sent
rf.session._pools.get.assert_called_once_with(specific_host)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY)
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY, tablet=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])

# Simulate a ServerError response (which triggers RETRY_NEXT_HOST by default)
Expand Down
Loading
Loading