|
| 1 | +require 'socket' |
| 2 | +require 'digest/md5' |
| 3 | + |
| 4 | +BACKEND_MESSAGE_CODES = { |
| 5 | + 'Z' => "ReadyForQuery", |
| 6 | + 'C' => "CommandComplete", |
| 7 | + 'T' => "RowDescription", |
| 8 | + 'D' => "DataRow", |
| 9 | + '1' => "ParseComplete", |
| 10 | + '2' => "BindComplete", |
| 11 | + 'E' => "ErrorResponse", |
| 12 | + 's' => "PortalSuspended", |
| 13 | +} |
| 14 | + |
| 15 | +class PostgresSocket |
| 16 | + def initialize(host, port) |
| 17 | + @port = port |
| 18 | + @host = host |
| 19 | + @socket = TCPSocket.new @host, @port |
| 20 | + @parameters = {} |
| 21 | + @verbose = true |
| 22 | + end |
| 23 | + |
| 24 | + def send_md5_password_message(username, password, salt) |
| 25 | + m = Digest::MD5.hexdigest(password + username) |
| 26 | + m = Digest::MD5.hexdigest(m + salt.map(&:chr).join("")) |
| 27 | + m = 'md5' + m |
| 28 | + bytes = (m.split("").map(&:ord) + [0]).flatten |
| 29 | + message_size = bytes.count + 4 |
| 30 | + |
| 31 | + message = [] |
| 32 | + |
| 33 | + message << 'p'.ord |
| 34 | + message << [message_size].pack('l>').unpack('CCCC') # 4 |
| 35 | + message << bytes |
| 36 | + message.flatten! |
| 37 | + |
| 38 | + |
| 39 | + @socket.write(message.pack('C*')) |
| 40 | + end |
| 41 | + |
| 42 | + def send_startup_message(username, database, password) |
| 43 | + message = [] |
| 44 | + |
| 45 | + message << [196608].pack('l>').unpack('CCCC') # 4 |
| 46 | + message << "user".split('').map(&:ord) # 4, 8 |
| 47 | + message << 0 # 1, 9 |
| 48 | + message << username.split('').map(&:ord) # 2, 11 |
| 49 | + message << 0 # 1, 12 |
| 50 | + message << "database".split('').map(&:ord) # 8, 20 |
| 51 | + message << 0 # 1, 21 |
| 52 | + message << database.split('').map(&:ord) # 2, 23 |
| 53 | + message << 0 # 1, 24 |
| 54 | + message << 0 # 1, 25 |
| 55 | + message.flatten! |
| 56 | + |
| 57 | + total_message_size = message.size + 4 |
| 58 | + |
| 59 | + message_len = [total_message_size].pack('l>').unpack('CCCC') |
| 60 | + |
| 61 | + @socket.write([message_len + message].flatten.pack('C*')) |
| 62 | + |
| 63 | + sleep 0.1 |
| 64 | + |
| 65 | + read_startup_response(username, password) |
| 66 | + end |
| 67 | + |
| 68 | + def read_startup_response(username, password) |
| 69 | + message_code, message_len = @socket.recv(5).unpack("al>") |
| 70 | + while message_code == 'R' |
| 71 | + auth_code = @socket.recv(4).unpack('l>').pop |
| 72 | + case auth_code |
| 73 | + when 5 # md5 |
| 74 | + salt = @socket.recv(4).unpack('CCCC') |
| 75 | + send_md5_password_message(username, password, salt) |
| 76 | + message_code, message_len = @socket.recv(5).unpack("al>") |
| 77 | + when 0 # trust |
| 78 | + break |
| 79 | + end |
| 80 | + end |
| 81 | + loop do |
| 82 | + message_code, message_len = @socket.recv(5).unpack("al>") |
| 83 | + if message_code == 'Z' |
| 84 | + @socket.recv(1).unpack("a") # most likely I |
| 85 | + break # We are good to go |
| 86 | + end |
| 87 | + if message_code == 'S' |
| 88 | + actual_message = @socket.recv(message_len - 4).unpack("C*") |
| 89 | + k,v = actual_message.pack('U*').split(/\x00/) |
| 90 | + @parameters[k] = v |
| 91 | + end |
| 92 | + if message_code == 'K' |
| 93 | + process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>") |
| 94 | + @parameters["process_id"] = process_id |
| 95 | + @parameters["secret_key"] = secret_key |
| 96 | + end |
| 97 | + end |
| 98 | + return @parameters |
| 99 | + end |
| 100 | + |
| 101 | + def cancel_query |
| 102 | + socket = TCPSocket.new @host, @port |
| 103 | + process_key = @parameters["process_id"] |
| 104 | + secret_key = @parameters["secret_key"] |
| 105 | + message = [] |
| 106 | + message << [16].pack('l>').unpack('CCCC') # 4 |
| 107 | + message << [80877102].pack('l>').unpack('CCCC') # 4 |
| 108 | + message << [process_key.to_i].pack('l>').unpack('CCCC') # 4 |
| 109 | + message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4 |
| 110 | + message.flatten! |
| 111 | + socket.write(message.flatten.pack('C*')) |
| 112 | + socket.close |
| 113 | + log "[F] Sent CancelRequest message" |
| 114 | + end |
| 115 | + |
| 116 | + def send_query_message(query) |
| 117 | + query_size = query.length |
| 118 | + message_size = 1 + 4 + query_size |
| 119 | + message = [] |
| 120 | + message << "Q".ord |
| 121 | + message << [message_size].pack('l>').unpack('CCCC') # 4 |
| 122 | + message << query.split('').map(&:ord) # 2, 11 |
| 123 | + message << 0 # 1, 12 |
| 124 | + message.flatten! |
| 125 | + @socket.write(message.flatten.pack('C*')) |
| 126 | + log "[F] Sent Q message (#{query})" |
| 127 | + end |
| 128 | + |
| 129 | + def send_parse_message(query) |
| 130 | + query_size = query.length |
| 131 | + message_size = 2 + 2 + 4 + query_size |
| 132 | + message = [] |
| 133 | + message << "P".ord |
| 134 | + message << [message_size].pack('l>').unpack('CCCC') # 4 |
| 135 | + message << 0 # unnamed statement |
| 136 | + message << query.split('').map(&:ord) # 2, 11 |
| 137 | + message << 0 # 1, 12 |
| 138 | + message << [0, 0] |
| 139 | + message.flatten! |
| 140 | + @socket.write(message.flatten.pack('C*')) |
| 141 | + log "[F] Sent P message (#{query})" |
| 142 | + end |
| 143 | + |
| 144 | + def send_bind_message |
| 145 | + message = [] |
| 146 | + message << "B".ord |
| 147 | + message << [12].pack('l>').unpack('CCCC') # 4 |
| 148 | + message << 0 # unnamed statement |
| 149 | + message << 0 # unnamed statement |
| 150 | + message << [0, 0] # 2 |
| 151 | + message << [0, 0] # 2 |
| 152 | + message << [0, 0] # 2 |
| 153 | + message.flatten! |
| 154 | + @socket.write(message.flatten.pack('C*')) |
| 155 | + log "[F] Sent B message" |
| 156 | + end |
| 157 | + |
| 158 | + def send_describe_message(mode) |
| 159 | + message = [] |
| 160 | + message << "D".ord |
| 161 | + message << [6].pack('l>').unpack('CCCC') # 4 |
| 162 | + message << mode.ord |
| 163 | + message << 0 # unnamed statement |
| 164 | + message.flatten! |
| 165 | + @socket.write(message.flatten.pack('C*')) |
| 166 | + log "[F] Sent D message" |
| 167 | + end |
| 168 | + |
| 169 | + def send_execute_message(limit=0) |
| 170 | + message = [] |
| 171 | + message << "E".ord |
| 172 | + message << [9].pack('l>').unpack('CCCC') # 4 |
| 173 | + message << 0 # unnamed statement |
| 174 | + message << [limit].pack('l>').unpack('CCCC') # 4 |
| 175 | + message.flatten! |
| 176 | + @socket.write(message.flatten.pack('C*')) |
| 177 | + log "[F] Sent E message" |
| 178 | + end |
| 179 | + |
| 180 | + def send_sync_message |
| 181 | + message = [] |
| 182 | + message << "S".ord |
| 183 | + message << [4].pack('l>').unpack('CCCC') # 4 |
| 184 | + message.flatten! |
| 185 | + @socket.write(message.flatten.pack('C*')) |
| 186 | + log "[F] Sent S message" |
| 187 | + end |
| 188 | + |
| 189 | + def send_copydone_message |
| 190 | + message = [] |
| 191 | + message << "c".ord |
| 192 | + message << [4].pack('l>').unpack('CCCC') # 4 |
| 193 | + message.flatten! |
| 194 | + @socket.write(message.flatten.pack('C*')) |
| 195 | + log "[F] Sent c message" |
| 196 | + end |
| 197 | + |
| 198 | + def send_copyfail_message |
| 199 | + message = [] |
| 200 | + message << "f".ord |
| 201 | + message << [5].pack('l>').unpack('CCCC') # 4 |
| 202 | + message << 0 |
| 203 | + message.flatten! |
| 204 | + @socket.write(message.flatten.pack('C*')) |
| 205 | + log "[F] Sent f message" |
| 206 | + end |
| 207 | + |
| 208 | + def send_flush_message |
| 209 | + message = [] |
| 210 | + message << "H".ord |
| 211 | + message << [4].pack('l>').unpack('CCCC') # 4 |
| 212 | + message.flatten! |
| 213 | + @socket.write(message.flatten.pack('C*')) |
| 214 | + log "[F] Sent H message" |
| 215 | + end |
| 216 | + |
| 217 | + def read_from_server() |
| 218 | + output_messages = [] |
| 219 | + retry_count = 0 |
| 220 | + message_code = nil |
| 221 | + message_len = 0 |
| 222 | + loop do |
| 223 | + begin |
| 224 | + message_code, message_len = @socket.recv_nonblock(5).unpack("al>") |
| 225 | + rescue IO::WaitReadable |
| 226 | + return output_messages if retry_count > 50 |
| 227 | + |
| 228 | + retry_count += 1 |
| 229 | + sleep(0.01) |
| 230 | + next |
| 231 | + end |
| 232 | + message = { |
| 233 | + code: message_code, |
| 234 | + len: message_len, |
| 235 | + bytes: [] |
| 236 | + } |
| 237 | + log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}" |
| 238 | + |
| 239 | + actual_message_length = message_len - 4 |
| 240 | + if actual_message_length > 0 |
| 241 | + message[:bytes] = @socket.recv(message_len - 4).unpack("C*") |
| 242 | + log "\t#{message[:bytes].join(",")}" |
| 243 | + log "\t#{message[:bytes].map(&:chr).join(" ")}" |
| 244 | + end |
| 245 | + output_messages << message |
| 246 | + return output_messages if message_code == 'Z' |
| 247 | + end |
| 248 | + end |
| 249 | + |
| 250 | + def log(msg) |
| 251 | + return unless @verbose |
| 252 | + |
| 253 | + puts msg |
| 254 | + end |
| 255 | + |
| 256 | + def close |
| 257 | + @socket.close |
| 258 | + end |
| 259 | +end |
0 commit comments