diff --git a/src/query_router.rs b/src/query_router.rs index 9d7a106a..c39b1b49 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -136,206 +136,206 @@ impl QueryRouter { &self.pool_settings } - /// Try to parse a command and execute it. - pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { - let mut message_cursor = Cursor::new(message_buffer); - - let code = message_cursor.get_u8() as char; - let len = message_cursor.get_i32() as usize; - + fn routing_shard_base_on_regex(&mut self, len: usize, message_buffer: &BytesMut) { let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some() || self.pool_settings.sharding_key_regex.is_some(); - // Check for any sharding regex matches in any queries - if comment_shard_routing_enabled { - match code as char { - // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement - 'P' | 'Q' => { - // Check only the first block of bytes configured by the pool settings - let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit); - - let query_start_index = mem::size_of::() + mem::size_of::(); - - let initial_segment = String::from_utf8_lossy( - &message_buffer[query_start_index..query_start_index + seg], - ); - - // Check for a shard_id included in the query - if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex { - let shard_id = shard_id_regex.captures(&initial_segment).and_then(|cap| { - cap.get(1).and_then(|id| id.as_str().parse::().ok()) - }); - if let Some(shard_id) = shard_id { - debug!("Setting shard to {:?}", shard_id); - self.set_shard(Some(shard_id)); - // Skip other command processing since a sharding command was found - return None; - } - } - - // Check for a sharding_key included in the query - if let Some(sharding_key_regex) = &self.pool_settings.sharding_key_regex { - let sharding_key = - sharding_key_regex - .captures(&initial_segment) - .and_then(|cap| { - cap.get(1).and_then(|id| id.as_str().parse::().ok()) - }); - if let Some(sharding_key) = sharding_key { - debug!("Setting sharding_key to {:?}", sharding_key); - self.set_sharding_key(sharding_key); - // Skip other command processing since a sharding command was found - return None; - } - } - } - _ => {} + if !comment_shard_routing_enabled { + return; + } + // Check only the first block of bytes configured by the pool settings + let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit); + + let query_start_index = mem::size_of::() + mem::size_of::(); + + let initial_segment = + String::from_utf8_lossy(&message_buffer[query_start_index..query_start_index + seg]); + + // Check for a shard_id included in the query + if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex { + let shard_id = shard_id_regex + .captures(&initial_segment) + .and_then(|cap| cap.get(1).and_then(|id| id.as_str().parse::().ok())); + if let Some(shard_id) = shard_id { + debug!("Setting shard to {:?}", shard_id); + self.set_shard(Some(shard_id)); + return; } } - // Only simple protocol supported for commands processed below - if code != 'Q' { - return None; + // Check for a sharding_key included in the query + if let Some(sharding_key_regex) = &self.pool_settings.sharding_key_regex { + let sharding_key = sharding_key_regex + .captures(&initial_segment) + .and_then(|cap| cap.get(1).and_then(|id| id.as_str().parse::().ok())); + if let Some(sharding_key) = sharding_key { + debug!("Setting sharding_key to {:?}", sharding_key); + self.set_sharding_key(sharding_key); + return; + } } + } - let query = message_cursor.read_string().unwrap(); + /// Try to parse a command and execute it. + pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { + let mut message_cursor = Cursor::new(message_buffer); - let regex_set = match CUSTOM_SQL_REGEX_SET.get() { - Some(regex_set) => regex_set, - None => return None, - }; + let code = message_cursor.get_u8() as char; + let len = message_cursor.get_i32() as usize; - let regex_list = match CUSTOM_SQL_REGEX_LIST.get() { - Some(regex_list) => regex_list, - None => return None, - }; + match code as char { + // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement + 'P' => { + self.routing_shard_base_on_regex(len, message_buffer); + None + } + 'Q' => { + let query = message_cursor.read_string().unwrap(); - let matches: Vec<_> = regex_set.matches(&query).into_iter().collect(); + let regex_set = match CUSTOM_SQL_REGEX_SET.get() { + Some(regex_set) => regex_set, + None => return None, + }; - // This is not a custom query, try to infer which - // server it'll go to if the query parser is enabled. - if matches.len() != 1 { - debug!("Regular query, not a command"); - return None; - } + let regex_list = match CUSTOM_SQL_REGEX_LIST.get() { + Some(regex_list) => regex_list, + None => return None, + }; - let command = match matches[0] { - 0 => Command::SetShardingKey, - 1 => Command::SetShard, - 2 => Command::ShowShard, - 3 => Command::SetServerRole, - 4 => Command::ShowServerRole, - 5 => Command::SetPrimaryReads, - 6 => Command::ShowPrimaryReads, - _ => unreachable!(), - }; + let matches: Vec<_> = regex_set.matches(&query).into_iter().collect(); - let mut value = match command { - Command::SetShardingKey - | Command::SetShard - | Command::SetServerRole - | Command::SetPrimaryReads => { - // Capture value. I know this re-runs the regex engine, but I haven't - // figured out a better way just yet. I think I can write a single Regex - // that matches all 5 custom SQL patterns, but maybe that's not very legible? - // - // I think this is faster than running the Regex engine 5 times. - match regex_list[matches[0]].captures(&query) { - Some(captures) => match captures.get(1) { - Some(value) => value.as_str().to_string(), - None => return None, - }, - None => return None, + // This is not a custom query, try to infer which + // server it'll go to if the query parser is enabled. + if matches.len() != 1 { + self.routing_shard_base_on_regex(len, message_buffer); + debug!("Regular query, not a command"); + return None; } - } - Command::ShowShard => self - .shard() - .map_or_else(|| "unset".to_string(), |x| x.to_string()), - Command::ShowServerRole => match self.active_role { - Some(Role::Primary) => Role::Primary.to_string(), - Some(Role::Replica) => Role::Replica.to_string(), - Some(Role::Mirror) => Role::Mirror.to_string(), - None => { - if self.query_parser_enabled() { - String::from("auto") - } else { - String::from("any") - } - } - }, + let command = match matches[0] { + 0 => Command::SetShardingKey, + 1 => Command::SetShard, + 2 => Command::ShowShard, + 3 => Command::SetServerRole, + 4 => Command::ShowServerRole, + 5 => Command::SetPrimaryReads, + 6 => Command::ShowPrimaryReads, + _ => unreachable!(), + }; - Command::ShowPrimaryReads => match self.primary_reads_enabled() { - true => String::from("on"), - false => String::from("off"), - }, - }; + let mut value = match command { + Command::SetShardingKey + | Command::SetShard + | Command::SetServerRole + | Command::SetPrimaryReads => { + // Capture value. I know this re-runs the regex engine, but I haven't + // figured out a better way just yet. I think I can write a single Regex + // that matches all 5 custom SQL patterns, but maybe that's not very legible? + // + // I think this is faster than running the Regex engine 5 times. + match regex_list[matches[0]].captures(&query) { + Some(captures) => match captures.get(1) { + Some(value) => value.as_str().to_string(), + None => return None, + }, + None => return None, + } + } - match command { - Command::SetShardingKey => { - // TODO: some error handling here - value = self - .set_sharding_key(value.parse::().unwrap()) - .unwrap() - .to_string(); - } + Command::ShowShard => self + .shard() + .map_or_else(|| "unset".to_string(), |x| x.to_string()), + Command::ShowServerRole => match self.active_role { + Some(Role::Primary) => Role::Primary.to_string(), + Some(Role::Replica) => Role::Replica.to_string(), + Some(Role::Mirror) => Role::Mirror.to_string(), + None => { + if self.query_parser_enabled() { + String::from("auto") + } else { + String::from("any") + } + } + }, - Command::SetShard => { - self.active_shard = match value.to_ascii_uppercase().as_ref() { - "ANY" => Some(rand::random::() % self.pool_settings.shards), - _ => Some(value.parse::().unwrap()), + Command::ShowPrimaryReads => match self.primary_reads_enabled() { + true => String::from("on"), + false => String::from("off"), + }, }; - } - Command::SetServerRole => { - self.active_role = match value.to_ascii_lowercase().as_ref() { - "primary" => { - self.query_parser_enabled = Some(false); - Some(Role::Primary) + match command { + Command::SetShardingKey => { + // TODO: some error handling here + value = self + .set_sharding_key(value.parse::().unwrap()) + .unwrap() + .to_string(); + // Since we are processing a set sharding key command bypass routing_shard_base_on_regex + return Some((command, value)); } - "replica" => { - self.query_parser_enabled = Some(false); - Some(Role::Replica) + Command::SetShard => { + self.active_shard = match value.to_ascii_uppercase().as_ref() { + "ANY" => Some(rand::random::() % self.pool_settings.shards), + _ => Some(value.parse::().unwrap()), + }; + // Since we are processing a set shard command bypass routing_shard_base_on_regex + return Some((command, value)); } - "any" => { - self.query_parser_enabled = Some(false); - None - } + Command::SetServerRole => { + self.active_role = match value.to_ascii_lowercase().as_ref() { + "primary" => { + self.query_parser_enabled = Some(false); + Some(Role::Primary) + } - "auto" => { - self.query_parser_enabled = Some(true); - None - } + "replica" => { + self.query_parser_enabled = Some(false); + Some(Role::Replica) + } - "default" => { - self.active_role = self.pool_settings.default_role; - self.query_parser_enabled = None; - self.active_role + "any" => { + self.query_parser_enabled = Some(false); + None + } + + "auto" => { + self.query_parser_enabled = Some(true); + None + } + + "default" => { + self.active_role = self.pool_settings.default_role; + self.query_parser_enabled = None; + self.active_role + } + + _ => unreachable!(), + }; } - _ => unreachable!(), - }; - } + Command::SetPrimaryReads => { + if value == "on" { + debug!("Setting primary reads to on"); + self.primary_reads_enabled = Some(true); + } else if value == "off" { + debug!("Setting primary reads to off"); + self.primary_reads_enabled = Some(false); + } else if value == "default" { + debug!("Setting primary reads to default"); + self.primary_reads_enabled = None; + } + } - Command::SetPrimaryReads => { - if value == "on" { - debug!("Setting primary reads to on"); - self.primary_reads_enabled = Some(true); - } else if value == "off" { - debug!("Setting primary reads to off"); - self.primary_reads_enabled = Some(false); - } else if value == "default" { - debug!("Setting primary reads to default"); - self.primary_reads_enabled = None; + _ => (), } + self.routing_shard_base_on_regex(len, message_buffer); + Some((command, value)) } - - _ => (), + _ => None, } - - Some((command, value)) } pub fn parse(&self, message: &BytesMut) -> Result, Error> { @@ -1157,6 +1157,26 @@ mod test { )) ); } + + // Test special command parsing correctly after setting sharding regexes + let mut qr_with_regex = QueryRouter { + active_shard: None, + active_role: None, + query_parser_enabled: None, + primary_reads_enabled: None, + pool_settings: PoolSettings { + sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()), + shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()), + ..Default::default() + }, + placeholders: Vec::new(), + }; + let set_shard_query = simple_query("SET SHARD TO '1'"); + assert_eq!( + qr_with_regex.try_execute_command(&set_shard_query), + Some((Command::SetShard, String::from("1"))) + ); + assert_eq!(Some(1), qr_with_regex.active_shard); } #[test]