File: //usr/local/rvm/gems/ruby-2.7.4/gems/ruby-mysql-2.11.0/lib/mysql/protocol.rb
# coding: ascii-8bit
# Copyright (C) 2008 TOMITA Masahiro
# mailto:tommy@tmtm.org
require "socket"
require "stringio"
require "openssl"
require_relative 'authenticator.rb'
class Mysql
# MySQL network protocol
class Protocol
# Protocol version the client expects in the server's initial handshake packet.
VERSION = 10
# Largest payload one wire packet can carry (the length field is 3 bytes).
# Longer payloads are split into several packets by write and reassembled by read.
MAX_PACKET_LENGTH = 0xff_ffff
# Convert a value of a binary-protocol (prepared statement) row to a Ruby value.
# === Argument
# pkt :: [Packet] packet positioned at the value
# type :: [Integer] field type (Field::TYPE_*)
# unsigned :: [true or false] true if value is unsigned
# === Return
# Object :: converted value (String, Integer, Float or Mysql::Time).
# === Exception
# RuntimeError :: the field type is not supported
#
# DATE/DATETIME/TIMESTAMP/TIME values are length-prefixed and may be shorter
# than the full layout (a length of 0 means every component is zero). Missing
# components come back from unpack as nil and are normalized to 0 here; in
# particular a zero-length TIME is 00:00:00 and must not be flagged negative
# (nil != 0 would be true).
def self.net2value(pkt, type, unsigned)
  case type
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
    pkt.lcs
  when Field::TYPE_TINY
    v = pkt.utiny
    unsigned || v < 0x80 ? v : v - 0x100
  when Field::TYPE_SHORT
    v = pkt.ushort
    unsigned || v < 0x8000 ? v : v - 0x1_0000
  when Field::TYPE_INT24, Field::TYPE_LONG
    v = pkt.ulong
    unsigned || v < 0x8000_0000 ? v : v - 0x1_0000_0000
  when Field::TYPE_LONGLONG
    low = pkt.ulong
    high = pkt.ulong
    v = (high << 32) | low
    unsigned || v < 0x8000_0000_0000_0000 ? v : v - 0x1_0000_0000_0000_0000
  when Field::TYPE_FLOAT
    pkt.read(4).unpack1('e')
  when Field::TYPE_DOUBLE
    pkt.read(8).unpack1('E')
  when Field::TYPE_DATE
    len = pkt.utiny
    y, m, d = pkt.read(len).unpack("vCC")
    # time components stay nil: the value is a date only
    Mysql::Time.new(y.to_i, m.to_i, d.to_i, nil, nil, nil)
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
    len = pkt.utiny
    y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
    Mysql::Time.new(y.to_i, m.to_i, d.to_i, h.to_i, mi.to_i, s.to_i, false, sp)
  when Field::TYPE_TIME
    len = pkt.utiny
    sign, days, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
    hours = days.to_i * 24 + h.to_i
    Mysql::Time.new(0, 0, 0, hours, mi.to_i, s.to_i, sign.to_i != 0, sp)
  when Field::TYPE_YEAR
    pkt.ushort
  when Field::TYPE_BIT
    pkt.lcs
  else
    raise "not implemented: type=#{type}"
  end
end
# Encode a Ruby value as a binary-protocol prepared-statement parameter.
# === Argument
# v :: [Object] Ruby value (nil, Integer, Float, String, ::Time or Mysql::Time).
# === Return
# Integer :: column type, Field::TYPE_*; 0x8000 is or-ed in for integers that
#            only fit an unsigned 64-bit column
# String :: encoded value
# === Exception
# ProtocolError :: value too large / value is not supported
def self.value2net(v)
  case v
  when nil
    [Field::TYPE_NULL, ""]
  when Integer
    halves = [v & 0xffff_ffff, v >> 32] # low and high 32 bits, little-endian order
    case v
    when -0x8000_0000...0x8000_0000
      [Field::TYPE_LONG, [v].pack('V')]
    when -0x8000_0000_0000_0000...0x8000_0000_0000_0000
      [Field::TYPE_LONGLONG, halves.pack("VV")]
    when 0x8000_0000_0000_0000..0xffff_ffff_ffff_ffff
      [Field::TYPE_LONGLONG | 0x8000, halves.pack("VV")]
    else
      raise ProtocolError, "value too large: #{v}"
    end
  when Float
    [Field::TYPE_DOUBLE, [v].pack("E")]
  when String
    [Field::TYPE_STRING, Packet.lcs(v)]
  when ::Time
    [Field::TYPE_DATETIME, [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")]
  when Mysql::Time
    [Field::TYPE_DATETIME, [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.second_part].pack("CvCCCCCV")]
  else
    raise ProtocolError, "class #{v.class} is not supported"
  end
end
# Handshake information recorded by authenticate.
attr_reader :server_info, :server_version, :thread_id, :client_flags
# Results of the most recent command, as reported by the server.
attr_reader :sqlstate, :affected_rows, :insert_id, :server_status, :warning_count, :message
# Value of the :get_server_public_key option passed to the constructor.
attr_reader :get_server_public_key
# Connection character set (Mysql::Charset).
attr_accessor :charset
# @state keeps the state of the connection:
# :INIT :: Initial state.
# :READY :: Ready for command.
# :FIELD :: After query(). retr_fields() is needed.
# :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
#
# Make a socket connection to the server.
# @param host [String] if "localhost" or "" or nil then use UNIX socket. Otherwise use TCP socket
# @param port [Integer] port number used by the TCP socket; nil: $MYSQL_TCP_PORT,
#   the "mysql" service port, or MYSQL_TCP_PORT
# @param socket [String] socket file name used by the UNIX socket; nil: $MYSQL_UNIX_PORT
#   or MYSQL_UNIX_PORT
# @param [Hash] opts
# @option opts :connect_timeout [Integer] connect timeout (sec). Only used for TCP connections.
# @option opts :read_timeout [Integer] read timeout (sec).
# @option opts :write_timeout [Integer] write timeout (sec).
# @option opts :local_infile [String] local infile path
# @option opts :ssl_mode SSL mode (one of the SSL_MODE_* constants), used by enable_ssl
# @option opts :get_server_public_key [Boolean]
# @raise [ClientError] connection timeout
def initialize(host, port, socket, opts)
  @opts = opts
  @insert_id = 0
  @warning_count = 0
  @gc_stmt_queue = [] # ids of statements to close, queued by gc_stmt
  set_state :INIT
  @get_server_public_key = @opts[:get_server_public_key]
  begin
    if host.nil? || host.empty? || host == "localhost"
      socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
      @socket = Socket.unix(socket)
    else
      port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
      @socket = Socket.tcp(host, port, connect_timeout: @opts[:connect_timeout])
    end
  rescue Errno::ETIMEDOUT
    raise ClientError, "connection timeout"
  end
end

# Close the underlying socket (plain or TLS-wrapped).
def close
  @socket.close
end
# Read the server's handshake, negotiate capabilities (and TLS) and log in.
# === Argument
# user :: [String / nil] username
# passwd :: [String / nil] password
# db :: [String / nil] default database name. nil: no default.
# flag :: [Integer] extra client capability flags, or-ed into the defaults
# charset :: [Mysql::Charset / nil] charset for connection. nil: use server's charset
# === Exception
# ProtocolError :: The old style password is not supported
def authenticate(user, passwd, db, flag, charset)
  check_state :INIT
  @authinfo = [user, passwd, db, flag, charset]
  reset
  handshake = InitialPacket.parse(read)
  version = handshake.server_version
  @server_info = version
  # "8.0.26" -> 80026: the first three numeric components, two digits each after the first
  @server_version = version.split(/\D/)[0, 3].inject { |acc, part| acc.to_i * 100 + part.to_i }
  @server_capabilities = handshake.server_capabilities
  @thread_id = handshake.thread_id
  @client_flags = [CLIENT_LONG_PASSWORD, CLIENT_LONG_FLAG, CLIENT_TRANSACTIONS,
                   CLIENT_PROTOCOL_41, CLIENT_SECURE_CONNECTION, CLIENT_PLUGIN_AUTH].inject(:|)
  @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile]
  @client_flags |= CLIENT_CONNECT_WITH_DB if db
  @client_flags |= flag
  @charset = charset || Charset.by_number(handshake.server_charset)
  @charset.encoding unless charset # raise error early if the server's charset is unsupported
  enable_ssl
  Authenticator.new(self).authenticate(user, passwd, db, handshake.scramble_buff, handshake.auth_plugin)
  set_state :READY
end
# Upgrade the connection to TLS according to @opts[:ssl_mode].
# SSL_MODE_DISABLED :: never use TLS.
# SSL_MODE_PREFERRED :: use TLS over TCP when the server advertises CLIENT_SSL;
#                       otherwise stay in plain text.
# SSL_MODE_REQUIRED :: use TLS; raise if the server does not advertise CLIENT_SSL.
# Any other mode raises ClientError.
#
# Falling back to plain text is decided only from the checks above. Once the
# TLS request packet has been written, the server expects a TLS handshake, so
# a failure from then on leaves the connection unusable. The error is therefore
# raised in every mode, including PREFERRED, instead of continuing on a
# half-initialized socket.
# NOTE(review): the SSLSocket is created without an explicit SSLContext, so no
# certificate verification is configured here.
def enable_ssl
  case @opts[:ssl_mode]
  when SSL_MODE_DISABLED
    return
  when SSL_MODE_PREFERRED
    return if @socket.local_address.unix?
    return if @server_capabilities & CLIENT_SSL == 0
  when SSL_MODE_REQUIRED
    if @server_capabilities & CLIENT_SSL == 0
      raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
    end
  else
    raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
  end
  begin
    @client_flags |= CLIENT_SSL
    write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
    @socket = OpenSSL::SSL::SSLSocket.new(@socket)
    @socket.sync_close = true
    @socket.connect
  rescue
    @client_flags &= ~CLIENT_SSL
    raise
  end
end
# Quit command: send COM_QUIT and close the socket.
# Statement ids queued by gc_stmt are dropped rather than sent. The server
# releases a session's prepared statements when the session ends. Also,
# synchronize calls set_state(:READY) after this block, and that call would
# otherwise write COM_STMT_CLOSE packets to the socket that was just closed.
def quit_command
  synchronize do
    reset
    write [COM_QUIT].pack("C")
    close
    @gc_stmt_queue.clear
  end
end
# Query command: send COM_QUERY and read the server's first response.
# On error the state is reset to :READY and the error is re-raised.
# === Argument
# query :: [String] query string, converted to the connection charset
# === Return
# [Integer / nil] number of fields of results. nil if no results.
def query_command(query)
  check_state :READY
  begin
    reset
    packet = [COM_QUERY, @charset.convert(query)].pack("Ca*")
    write packet
    get_result
  rescue
    set_state :READY
    raise
  end
end
# Read the response to a query (COM_QUERY or COM_STMT_EXECUTE).
# When the server requests a local file (LOAD DATA LOCAL INFILE), the file is
# sent and the packet the server replies with afterwards is used as the
# result. affected_rows, server_status, etc. therefore describe the load.
# On error the state is reset to :READY and the error is re-raised.
# === Return
# [Integer / nil] number of fields of results. nil if no results.
def get_result
  res_packet = ResultPacket.parse read
  if res_packet.field_count.to_i > 0 # result data exists
    set_state :FIELD
    return res_packet.field_count
  end
  if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
    res_packet = ResultPacket.parse(send_local_file(res_packet.message))
  end
  @affected_rows, @insert_id, @server_status, @warning_count, @message =
    res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
  set_state :READY
  nil
rescue
  set_state :READY
  raise
end

# Send a file requested by the server for LOAD DATA LOCAL INFILE, then the
# empty packet that ends the transfer, and read the server's reply.
# The server's request is not trusted. It is honored only when the
# :local_infile option is a string and the absolute path of +filename+ starts
# with it; otherwise ClientError::LoadDataLocalInfileRejected is raised.
# === Argument
# filename :: [String] file name requested by the server
# === Return
# [Packet] the server's reply to the transfer
def send_local_file(filename)
  reply = nil
  begin
    allowed = @opts[:local_infile]
    path = File.absolute_path(filename)
    unless allowed.respond_to?(:to_str) && path.start_with?(allowed)
      raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
    end
    File.open(path) { |f| write f }
  ensure
    # Even after a rejection, the terminating packet must be sent and the
    # server's reply consumed to keep the packet sequence in sync.
    write nil # EOF mark
    reply = read
  end
  reply
end
# Read the field (column) definitions of a result set.
# Consumes +n+ field packets plus the EOF packet after them and moves the
# connection to :RESULT. On error the state is reset to :READY and the error
# is re-raised.
# === Argument
# n :: [Integer] number of fields
# === Return
# [Array of Mysql::Field] field list
def retr_fields(n)
  check_state :FIELD
  begin
    field_list = n.times.map do
      Field.new(FieldPacket.parse(read))
    end
    read_eof_packet
    set_state :RESULT
    field_list
  rescue
    set_state :READY
    raise
  end
end
# Retrieve all rows of a text-protocol result set (after retr_fields).
# The warning count and server status are taken from the EOF packet that ends
# the rows (0xfe, warnings: 2 bytes, status flags: 2 bytes). The status is
# read as 2 bytes; reading only one byte would drop flags above 0xff.
# The state returns to :READY whether or not an error occurs.
# === Argument
# fields :: [Array<Mysql::Field>] fields of the result set
# === Return
# [Array of RawRecord] all records
def retr_all_records(fields)
  check_state :RESULT
  enc = charset.encoding
  begin
    all_recs = []
    until (pkt = read).eof?
      all_recs.push RawRecord.new(pkt, fields, enc)
    end
    pkt.read(1) # 0xfe EOF marker
    @warning_count = pkt.ushort
    @server_status = pkt.ushort
    all_recs
  ensure
    set_state :READY
  end
end
# Field list command (COM_FIELD_LIST).
# === Argument
# table :: [String] table name.
# field :: [String / nil] field name that may contain wild card.
#          nil is sent as an empty wildcard.
# === Return
# [Array of Field] field list
def field_list_command(table, field)
  synchronize do
    reset
    # pack("a*") raises TypeError for nil, so a missing wildcard becomes ""
    write [COM_FIELD_LIST, table, 0, field.to_s].pack("Ca*Ca*")
    fields = []
    until (data = read).eof?
      fields.push Field.new(FieldPacket.parse(data))
    end
    fields
  end
end
# Process info command (COM_PROCESS_INFO).
# Reads the column count and field definitions of the process-list result
# set and moves the connection to :RESULT; the rows are read afterwards.
# On error the state is reset to :READY and the error is re-raised.
# === Return
# [Array of Field] field list
def process_info_command
  check_state :READY
  begin
    reset
    write [COM_PROCESS_INFO].pack("C")
    column_count = read.lcb
    field_list = column_count.times.map do
      Field.new(FieldPacket.parse(read))
    end
    read_eof_packet
    set_state :RESULT
    field_list
  rescue
    set_state :READY
    raise
  end
end
# Ping command (COM_PING).
def ping_command
  simple_command([COM_PING].pack("C"))
end

# Kill command (COM_PROCESS_KILL); +pid+ is sent as a 4-byte integer.
def kill_command(pid)
  simple_command([COM_PROCESS_KILL, pid].pack("CV"))
end

# Refresh command (COM_REFRESH); +op+ is sent as one byte.
def refresh_command(op)
  simple_command([COM_REFRESH, op].pack("CC"))
end

# Set option command (COM_SET_OPTION); +opt+ is sent as a 2-byte integer.
def set_option_command(opt)
  simple_command([COM_SET_OPTION, opt].pack("Cv"))
end

# Shutdown command (COM_SHUTDOWN); +level+ is sent as one byte.
def shutdown_command(level)
  simple_command([COM_SHUTDOWN, level].pack("CC"))
end

# Statistics command (COM_STATISTICS); returns the server's reply as a String.
def statistics_command
  simple_command([COM_STATISTICS].pack("C"))
end
# Stmt prepare command (COM_STMT_PREPARE).
# Parameter definition packets are read and discarded; column definitions are
# returned as fields.
# === Argument
# stmt :: [String] prepared statement, converted to the connection charset
# === Return
# [Integer] statement id
# [Integer] number of parameters
# [Array of Field] field list
def stmt_prepare_command(stmt)
  synchronize do
    reset
    write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
    prepared = PrepareResultPacket.parse(read)
    if prepared.param_count > 0
      prepared.param_count.times { read } # parameter definitions are not used
      read_eof_packet
    end
    field_list =
      if prepared.field_count > 0
        list = prepared.field_count.times.map { Field.new(FieldPacket.parse(read)) }
        read_eof_packet
        list
      else
        []
      end
    [prepared.statement_id, prepared.param_count, field_list]
  end
end
# Stmt execute command (COM_STMT_EXECUTE, no cursor).
# On error the state is reset to :READY and the error is re-raised.
# === Argument
# stmt_id :: [Integer] statement id
# values :: [Array] parameters
# === Return
# [Integer / nil] number of fields; nil if no result set
def stmt_execute_command(stmt_id, values)
  check_state :READY
  begin
    reset
    packet = ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
    write packet
    get_result
  rescue
    set_state :READY
    raise
  end
end
# Retrieve all rows of a binary-protocol (prepared statement) result set.
# As in retr_all_records, the warning count and the 2-byte server status are
# taken from the EOF packet that ends the rows. For example, status flags
# announcing further result sets are then visible after a prepared
# statement. The state returns to :READY whether or not an error occurs.
# === Argument
# fields :: [Array of Mysql::Fields] field list
# charset :: [Mysql::Charset]
# === Return
# [Array of StmtRawRecord] all records
def stmt_retr_all_records(fields, charset)
  check_state :RESULT
  enc = charset.encoding
  begin
    all_recs = []
    until (pkt = read).eof?
      all_recs.push StmtRawRecord.new(pkt, fields, enc)
    end
    pkt.read(1) # 0xfe EOF marker
    @warning_count = pkt.ushort
    @server_status = pkt.ushort
    all_recs
  ensure
    set_state :READY
  end
end
# Stmt close command (COM_STMT_CLOSE). No reply is read after sending it.
# === Argument
# stmt_id :: [Integer] statement id
def stmt_close_command(stmt_id)
  synchronize do
    reset
    packet = [COM_STMT_CLOSE, stmt_id].pack("CV")
    write packet
  end
end
# Queue a statement id to be closed on the server. The COM_STMT_CLOSE is
# sent by set_state the next time the connection becomes :READY.
# NOTE(review): presumably called when a statement object is garbage
# collected — confirm against Mysql::Stmt.
def gc_stmt(stmt_id)
  @gc_stmt_queue << stmt_id
end

# Raise 'command out of sync' unless the connection is in state +st+.
def check_state(st)
  return if @state == st
  raise 'command out of sync'
end

# Switch the connection state to +st+. On entering :READY, send a
# COM_STMT_CLOSE for every queued statement id, with GC disabled while the
# queue is drained (re-enabled afterwards unless it was already disabled).
def set_state(st)
  @state = st
  return unless st == :READY
  return if @gc_stmt_queue.empty?
  was_disabled = GC.disable
  begin
    while (stmt_id = @gc_stmt_queue.shift)
      reset
      write [COM_STMT_CLOSE, stmt_id].pack("CV")
    end
  ensure
    GC.enable unless was_disabled
  end
end

# Run the block as one command: the connection must be :READY, and it is set
# back to :READY afterwards even if the block (or the state check) raises.
def synchronize
  check_state :READY
  yield
ensure
  set_state :READY
end

# Reset the packet sequence number; done at the start of every command.
def reset
  @seq = 0
end
# Read one packet from the server.
# A payload of MAX_PACKET_LENGTH bytes or more arrives as several wire packets;
# they are reassembled here. An error packet is raised as an exception.
# === Return
# [Packet] packet data
# === Exception
# [ProtocolError] invalid packet sequence number
# [ClientError::ServerGoneError] the connection was closed
# [ClientError] read timeout
# [Mysql::ServerError] the server sent an error packet
def read
  data = String.new
  len = nil
  loop do
    begin
      header = read_timeout(4, @opts[:read_timeout])
      raise EOFError unless header && header.bytesize == 4
      len1, len2, seq = header.unpack("CvC")
      len = (len2 << 8) + len1
      raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
      @seq = (@seq + 1) % 256
      ret = read_timeout(len, @opts[:read_timeout])
      raise EOFError unless ret && ret.bytesize == len
      data.concat ret
    rescue EOFError
      raise ClientError::ServerGoneError, 'MySQL server has gone away'
    rescue Errno::ETIMEDOUT
      raise ClientError, "read timeout"
    end
    break unless len == MAX_PACKET_LENGTH
  end
  @sqlstate = "00000"
  # Compare the byte value, so the check does not depend on string encodings.
  if data.getbyte(0) == 0xff # Error packet
    _, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
    unless marker == "#"
      _, errno, message = data.unpack("Cva*") # Version 4.0 Error
      @sqlstate = ""
    end
    # @charset is still unset when the server rejects the connection in place
    # of its initial handshake (e.g. "Too many connections"); the message then
    # keeps its raw bytes.
    message.force_encoding(@charset.encoding) if @charset
    if Mysql::ServerError::ERROR_MAP.key? errno
      raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
    end
    raise Mysql::ServerError.new(message, @sqlstate, errno)
  end
  Packet.new(data)
end
# Read +len+ bytes from the socket, giving up after +timeout+ seconds.
# Like IO#read, if the peer closes the connection first, the bytes received
# so far are returned (nil if none); read treats that as EOF.
# === Argument
# len :: [Integer] number of bytes to read
# timeout :: [Numeric / nil] seconds; nil or 0 means a plain blocking @socket.read
# === Return
# [String / nil] data read
# === Exception
# [Errno::ETIMEDOUT] the timeout expired
def read_timeout(len, timeout)
  return @socket.read(len) if timeout.nil? || timeout == 0
  result = String.new
  deadline = ::Time.now + timeout
  while result.bytesize < len
    now = ::Time.now
    raise Errno::ETIMEDOUT if now > deadline
    chunk = @socket.read_nonblock(len - result.bytesize, exception: false)
    case chunk
    when :wait_readable
      IO.select([@socket], nil, nil, deadline - now)
    when :wait_writable
      IO.select(nil, [@socket], nil, deadline - now)
    when nil
      # read_nonblock returns nil at EOF; appending nil would raise TypeError
      return result.empty? ? nil : result
    else
      result << chunk
    end
  end
  result
end
# Write packet data to the server, split into wire packets of at most
# MAX_PACKET_LENGTH bytes, each with a 4-byte header (3-byte length, sequence).
# A wire packet shorter than MAX_PACKET_LENGTH ends a logical packet. An empty
# packet is therefore appended when the last chunk is exactly
# MAX_PACKET_LENGTH bytes, and an empty String is sent as one empty packet.
# An IO with no data sends nothing.
# === Argument
# data :: [String / IO / nil] packet data. If data is nil, write empty packet.
# === Exception
# [ClientError::ServerGoneError] broken pipe
# [ClientError] write timeout
def write(data)
  timeout = @opts[:write_timeout]
  send_chunk = lambda do |chunk|
    size = chunk.bytesize
    write_timeout([size % 256, size / 256, @seq].pack("CvC") + chunk, timeout)
    @seq = (@seq + 1) % 256
  end
  @socket.sync = false
  if data.nil?
    send_chunk.call('')
  else
    string_data = data.is_a?(String)
    data = StringIO.new data if string_data
    last_size = nil
    while (chunk = data.read(MAX_PACKET_LENGTH))
      send_chunk.call(chunk)
      last_size = chunk.bytesize
    end
    send_chunk.call('') if last_size == MAX_PACKET_LENGTH || (string_data && last_size.nil?)
  end
  @socket.sync = true
  @socket.flush
rescue Errno::EPIPE
  raise ClientError::ServerGoneError, 'MySQL server has gone away'
rescue Errno::ETIMEDOUT
  raise ClientError, "write timeout"
end
# Write all of +data+ to the socket, giving up after +timeout+ seconds.
# Progress is tracked in bytes (write_nonblock returns a byte count), so
# partial writes are resumed at the right offset whatever the string encoding.
# === Argument
# data :: [String] bytes to write
# timeout :: [Numeric / nil] seconds; nil or 0 means a plain blocking @socket.write
# === Return
# [Integer] number of bytes written
# === Exception
# [Errno::ETIMEDOUT] the timeout expired
def write_timeout(data, timeout)
  return @socket.write(data) if timeout.nil? || timeout == 0
  total = data.bytesize
  written = 0
  deadline = ::Time.now + timeout
  while written < total
    now = ::Time.now
    raise Errno::ETIMEDOUT if now > deadline
    n = @socket.write_nonblock(data.byteslice(written, total - written), exception: false)
    case n
    when :wait_readable
      IO.select([@socket], nil, nil, deadline - now)
    when :wait_writable
      IO.select(nil, [@socket], nil, deadline - now)
    else
      written += n
    end
  end
  written
end
# Read one packet and require it to be an EOF packet.
# === Exception
# [ProtocolError] packet is not EOF
def read_eof_packet
  return if read.eof?
  raise ProtocolError, "packet is not EOF"
end

# Send a command that the server answers with a single packet.
# === Argument
# packet :: [String] packet data
# === Return
# [String] the reply packet's data
def simple_command(packet)
  synchronize do
    reset
    write packet
    response = read
    response.to_s
  end
end
# Initial handshake packet sent by the server when the connection opens.
class InitialPacket
  # Parse the handshake packet.
  # The protocol version is checked first: the rest of the layout parsed here
  # is only valid for protocol VERSION.
  # === Argument
  # pkt :: [Packet]
  # === Return
  # [InitialPacket]
  # === Exception
  # ProtocolError :: unsupported protocol version or malformed packet
  def self.parse(pkt)
    protocol_version = pkt.utiny
    raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
    server_version = pkt.string
    thread_id = pkt.ulong
    scramble_buff = pkt.read(8)
    f0 = pkt.utiny
    server_capabilities = pkt.ushort
    server_charset = pkt.utiny
    server_status = pkt.ushort
    server_capabilities2 = pkt.ushort
    scramble_length = pkt.utiny
    pkt.read(10) # reserved
    rest_scramble_buff = pkt.string
    auth_plugin = pkt.string
    server_capabilities |= server_capabilities2 << 16
    scramble_buff.concat rest_scramble_buff
    raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
    raise ProtocolError, "invalid packet: scramble_length(#{scramble_length}) != length of scramble(#{scramble_buff.size + 1})" unless scramble_length == scramble_buff.size + 1
    self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff, auth_plugin
  end

  attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff, :auth_plugin

  def initialize(*args)
    @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff, @auth_plugin = args
  end
end
# First response packet to a query or statement execution.
class ResultPacket
  # Parse the packet; its first length-coded value selects the form:
  #   0   :: OK packet, with affected_rows, insert_id, server_status, warning_count, message
  #   nil :: LOCAL INFILE request; message holds the rest of the packet (the file name)
  #   n   :: a result set with n columns follows; the other attributes are nil
  def self.parse(pkt)
    field_count = pkt.lcb
    case field_count
    when 0
      affected_rows = pkt.lcb
      insert_id = pkt.lcb
      server_status = pkt.ushort
      warning_count = pkt.ushort
      message = pkt.lcs
      new(field_count, affected_rows, insert_id, server_status, warning_count, message)
    when nil # LOAD DATA LOCAL INFILE
      new(nil, nil, nil, nil, nil, pkt.to_s)
    else
      new(field_count)
    end
  end

  attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

  def initialize(*args)
    @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
  end
end
# Column definition packet.
class FieldPacket
  # Parse a column definition.
  # === Exception
  # ProtocolError :: the 2-byte filler after decimals is not zero
  def self.parse(pkt)
    pkt.lcs # first string, not used
    db = pkt.lcs
    table = pkt.lcs
    org_table = pkt.lcs
    name = pkt.lcs
    org_name = pkt.lcs
    pkt.utiny # one byte, not used
    charsetnr = pkt.ushort
    length = pkt.ulong
    type = pkt.utiny
    flags = pkt.ushort
    decimals = pkt.utiny
    filler = pkt.ushort
    raise ProtocolError, "invalid packet: f1=#{filler}" unless filler == 0
    default = pkt.lcs
    new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
  end

  attr_reader :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

  def initialize(*args)
    @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
  end
end
# Response to COM_STMT_PREPARE.
class PrepareResultPacket
  # Parse the packet.
  # === Exception
  # ProtocolError :: the leading status byte or the filler byte is not zero
  def self.parse(pkt)
    status = pkt.utiny
    raise ProtocolError, "invalid packet" unless status == 0
    statement_id = pkt.ulong
    field_count = pkt.ushort
    param_count = pkt.ushort
    filler = pkt.utiny
    warning_count = pkt.ushort
    raise ProtocolError, "invalid packet" unless filler == 0x00
    new(statement_id, field_count, param_count, warning_count)
  end

  attr_reader :statement_id, :field_count, :param_count, :warning_count

  def initialize(*args)
    @statement_id, @field_count, @param_count, @warning_count = args
  end
end
# Handshake response (login) packet.
class AuthenticationPacket
  # Build the packet: flags, max packet size, charset, 23 zero bytes,
  # NUL-terminated username, length-coded auth data, the database name
  # (NUL-terminated, only when given) and the NUL-terminated auth plugin name.
  def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
    values = [client_flags, max_packet_size, charset_number, "", username, Packet.lcs(scrambled_password)]
    template = "VVCa23Z*A*"
    if databasename
      values << databasename
      template += "Z*"
    end
    values << auth_plugin
    template += "Z*"
    values.pack(template)
  end
end
# SSL request packet sent before the TLS handshake.
class TlsAuthenticationPacket
  # Build the packet: flags, max packet size, charset and 23 zero bytes.
  def self.serialize(client_flags, max_packet_size, charset_number)
    filler = "" # "a23" pads it to 23 NUL bytes
    [client_flags, max_packet_size, charset_number, filler].pack("VVCa23")
  end
end
# COM_STMT_EXECUTE packet.
class ExecutePacket
  # Build the packet: command, statement id, cursor type, the constant 1,
  # NULL bitmap, the constant 1, parameter types and the non-NULL values.
  # NOTE(review): the two 1s are presumably the iteration count and the
  # new-params-bound flag of the MySQL protocol.
  def self.serialize(statement_id, cursor_type, values)
    nbm = null_bitmap(values)
    types = []
    netvalues = ""
    values.each do |v|
      type, encoded = Protocol.value2net(v)
      netvalues.concat(encoded) if v
      types << type
    end
    [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
  end

  # Make the NULL bitmap: bit i of byte k is set when values[8 * k + i] is falsy.
  #
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
  def self.null_bitmap(values)
    values.each_slice(8).map { |group|
      group.each_with_index.sum { |v, i| v ? 0 : 1 << i }
    }.pack("C*")
  end
end
# Packet read during authentication: a status byte followed by two strings
# (auth plugin name and scramble).
class AuthenticationResultPacket
  def self.parse(pkt)
    status = pkt.utiny
    plugin = pkt.string
    salt = pkt.string
    new(status, plugin, salt)
  end

  attr_reader :result, :auth_plugin, :scramble

  def initialize(*args)
    @result, @auth_plugin, @scramble = args
  end
end
end
# One row of a text-protocol result set; decoded on demand by to_a.
class RawRecord
  def initialize(packet, fields, encoding)
    @packet = packet
    @fields = fields
    @encoding = encoding
  end

  # === Return
  # [Array of String / nil] one value per field. Values of BIT or binary-charset
  # fields are returned as read; others are converted to the connection
  # encoding. A nil from lcs (presumably SQL NULL) is returned as nil.
  def to_a
    @fields.map do |f|
      value = @packet.lcs
      next value unless value
      binary = f.type == Field::TYPE_BIT || f.charsetnr == Charset::BINARY_CHARSET_NUMBER
      binary ? value : Charset.convert_encoding(value, @encoding)
    end
  end
end
# One row of a binary-protocol (prepared statement) result set.
class StmtRawRecord
  # === Argument
  # packet :: [Packet]
  # fields :: [Array of Fields]
  # encoding :: [Encoding]
  def initialize(packet, fields, encoding)
    @packet = packet
    @fields = fields
    @encoding = encoding
  end

  # Decode the row packet: one leading byte (skipped), a NULL bitmap whose
  # bits for the fields start at bit 2, then the non-NULL values in field order.
  # === Return
  # [Array of Object] one record; numbers and Mysql::Time values as decoded,
  # BIT/binary-charset values as binary strings, other strings converted to
  # the connection encoding.
  def parse_record_packet
    @packet.utiny # skip first byte
    bitmap_bytes = (@fields.length + 7 + 2) / 8
    null_bits = @packet.read(bitmap_bytes).unpack1("b*")
    @fields.each_with_index.map do |f, i|
      next nil if null_bits[i + 2] == "1"
      unsigned = (f.flags & Field::UNSIGNED_FLAG) != 0
      v = Protocol.net2value(@packet, f.type, unsigned)
      if v.is_a?(Numeric) || v.is_a?(Mysql::Time)
        v
      elsif f.type == Field::TYPE_BIT || f.charsetnr == Charset::BINARY_CHARSET_NUMBER
        Charset.to_binary(v)
      else
        Charset.convert_encoding(v, @encoding)
      end
    end
  end
  alias to_a parse_record_packet
end
end