diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c index 5a95c13..2e78805 100644 --- a/ext/mysql2/statement.c +++ b/ext/mysql2/statement.c @@ -9,7 +9,6 @@ VALUE cMysql2Statement; static VALUE prepare(VALUE self, VALUE sql) { MYSQL_STMT * stmt; - Data_Get_Struct(self, MYSQL_STMT, stmt); if(mysql_stmt_prepare(stmt, StringValuePtr(sql), RSTRING_LEN(sql))) { @@ -19,9 +18,22 @@ static VALUE prepare(VALUE self, VALUE sql) return self; } +/* call-seq: stmt.param_count # => 2 + * + * Returns the number of parameters the prepared statement expects. + */ +static VALUE param_count(VALUE self) +{ + MYSQL_STMT * stmt; + Data_Get_Struct(self, MYSQL_STMT, stmt); + + return ULL2NUM(mysql_stmt_param_count(stmt)); +} + void init_mysql2_statement() { cMysql2Statement = rb_define_class_under(mMysql2, "Statement", rb_cObject); rb_define_method(cMysql2Statement, "prepare", prepare, 1); + rb_define_method(cMysql2Statement, "param_count", param_count, 0); } diff --git a/spec/mysql2/statement_spec.rb b/spec/mysql2/statement_spec.rb index df43970..8306f0b 100644 --- a/spec/mysql2/statement_spec.rb +++ b/spec/mysql2/statement_spec.rb @@ -2,7 +2,7 @@ require 'spec_helper' describe Mysql2::Statement do - before :all do + before :each do @client = Mysql2::Client.new :host => "localhost", :username => "root" end @@ -26,4 +26,13 @@ describe Mysql2::Statement do @client.close lambda { stmt.prepare 'SELECT 1' }.should raise_error(Mysql2::Error) end + + it "should tell us the param count" do + stmt = @client.create_statement + stmt.prepare 'SELECT ?, ?' + stmt.param_count.should == 2 + + stmt.prepare 'SELECT 1' + stmt.param_count.should == 0 + end end