diff --git a/ext/mysql2/mysql2_ext.c b/ext/mysql2/mysql2_ext.c index db06b32..1d6ba36 100644 --- a/ext/mysql2/mysql2_ext.c +++ b/ext/mysql2/mysql2_ext.c @@ -455,6 +455,21 @@ static VALUE init_connection(VALUE self) return self; } +/* call-seq: client.create_statement # => Mysql2::Statement + * + * Create a new prepared statement. + */ +static VALUE create_statement(VALUE self) +{ + MYSQL * client; + MYSQL_STMT * stmt; + + Data_Get_Struct(self, MYSQL, client); + stmt = mysql_stmt_init(client); + + return Data_Wrap_Struct(cMysql2Statement, 0, mysql_stmt_close, stmt); +} + /* Ruby Extension initializer */ void Init_mysql2() { mMysql2 = rb_define_module("Mysql2"); @@ -471,6 +486,7 @@ void Init_mysql2() { rb_define_method(cMysql2Client, "async_result", rb_mysql_client_async_result, 0); rb_define_method(cMysql2Client, "last_id", rb_mysql_client_last_id, 0); rb_define_method(cMysql2Client, "affected_rows", rb_mysql_client_affected_rows, 0); + rb_define_method(cMysql2Client, "create_statement", create_statement, 0); rb_define_private_method(cMysql2Client, "reconnect=", set_reconnect, 1); rb_define_private_method(cMysql2Client, "connect_timeout=", set_connect_timeout, 1); @@ -482,6 +498,7 @@ void Init_mysql2() { cMysql2Error = rb_const_get(mMysql2, rb_intern("Error")); init_mysql2_result(); + init_mysql2_statement(); sym_id = ID2SYM(rb_intern("id")); sym_version = ID2SYM(rb_intern("version")); diff --git a/ext/mysql2/mysql2_ext.h b/ext/mysql2/mysql2_ext.h index 0c18884..d2a915e 100644 --- a/ext/mysql2/mysql2_ext.h +++ b/ext/mysql2/mysql2_ext.h @@ -27,6 +27,7 @@ #endif #include +#include extern VALUE mMysql2; diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c new file mode 100644 index 0000000..5a95c13 --- /dev/null +++ b/ext/mysql2/statement.c @@ -0,0 +1,27 @@ +#include + +VALUE cMysql2Statement; + +/* call-seq: stmt.prepare(sql) + * + * Prepare +sql+ for execution + */ +static VALUE prepare(VALUE self, VALUE sql) +{ + MYSQL_STMT * stmt; + + Data_Get_Struct(self, MYSQL_STMT, stmt); + + if(mysql_stmt_prepare(stmt, StringValuePtr(sql), RSTRING_LEN(sql))) { + rb_raise(cMysql2Error, "%s", mysql_stmt_error(stmt)); + } + + return self; +} + +void init_mysql2_statement() +{ + cMysql2Statement = rb_define_class_under(mMysql2, "Statement", rb_cObject); + + rb_define_method(cMysql2Statement, "prepare", prepare, 1); +} diff --git a/ext/mysql2/statement.h b/ext/mysql2/statement.h new file mode 100644 index 0000000..cfea189 --- /dev/null +++ b/ext/mysql2/statement.h @@ -0,0 +1,8 @@ +#ifndef MYSQL2_STATEMENT_H +#define MYSQL2_STATEMENT_H + +extern VALUE cMysql2Statement; + +void init_mysql2_statement(); + +#endif diff --git a/spec/mysql2/statement_spec.rb b/spec/mysql2/statement_spec.rb new file mode 100644 index 0000000..df43970 --- /dev/null +++ b/spec/mysql2/statement_spec.rb @@ -0,0 +1,29 @@ +# encoding: UTF-8 +require 'spec_helper' + +describe Mysql2::Statement do + before :all do + @client = Mysql2::Client.new :host => "localhost", :username => "root" + end + + it "should create a statement" do + stmt = @client.create_statement + stmt.should be_kind_of Mysql2::Statement + end + + it "prepares some sql" do + stmt = @client.create_statement + lambda { stmt.prepare 'SELECT 1' }.should_not raise_error + end + + it "return self when prepare some sql" do + stmt = @client.create_statement + stmt.prepare('SELECT 1').should == stmt + end + + it "should raise an exception when server disconnects" do + stmt = @client.create_statement + @client.close + lambda { stmt.prepare 'SELECT 1' }.should raise_error(Mysql2::Error) + end +end