模仿spring在jdbc中为方法添加事务
2013-08-25 13:28 阅读(173)

在spring中,我们可以通过@Transactional注解,同时配合org.springframework.orm.hibernate3.HibernateTemplate类使用为方法添加事务,无论该方法中涉及多少数据库连接,都在该事务的控制之内。比如举个简单的例子,我有个为博客(blog)添加评论(comment)的方法,此时我需要完成两步:1、向数据库中插入一条评论记录  2、修改该评论所属的blog记录中的评论数字段(total_comment)。 这两步必须在同一个事务中进行,否则,一旦某个操作异常都会造成数据的不一致。简单的代码如下:

/**
* 增加评论
* @return
*/
public boolean addComment(Comment comment){
     //第一步,保存评论
     save(comment);
     //第二步,blog的评论数total_comment字段加1
     Blog blog = blogService.getBlogById(comment.getBlogId);
     blog.setTotal_comment(blog.getTotal_comment() + 1);
     update_total(blog );
}

由于save()和update_total()俩个方法都是通过各自的数据库连接(connection)与数据库打交道,所以如果要想让二者在同一个事务的控制下,必须有个机制,保证二者使用的是同一个连接(connection)。在spring中是通过将connection放入本地线程ThreadLocal中,当addComment()方法执行时,最初会创建个连接(connection)放入ThreadLocal,方法内所有涉及到数据库操作的方法都从这个ThreadLocal获得连接,这样就保证大家使用的是同一个connection。同时,由于连接被放入到ThreadLocal中,避免了多个线程执行时,connection的混乱。  

如果使用的是hibernate框架完成底层操作,使用org.springframework.orm.hibernate3.HibernateTemplate类配合@Transactional注解就可完成事务管理。如果使用纯jdbc的话,实现起来也很简单,个人觉得没有必要在使用第三方的东西。


下面的代码是我实现的一个,已经使用在项目中。 分享下。


package com.nec.dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import javax.naming.Context;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.sql.DataSource;

import org.springframework.beans.factory.annotation.Autowired;



public class GenericDao {
	
	private static ThreadLocal tansaction_connecton_map = new ThreadLocal();
	
	@Autowired
	private DataSourceFactory dataSourceFactory;
	

	/**
	 * 开启事务
	 * @return
	 */
	public void startTransaction()throws SQLException,NamingException{
		Connection con = createConnection();
		con.setAutoCommit(false);
		tansaction_connecton_map.set(con);
	}
	/**
	 * 提交事务
	 * @return
	 */
	public void commit()throws SQLException{
		Object obj = tansaction_connecton_map.get();
		if(obj != null){
			Connection con = (Connection)obj;
			con.commit();
		}
	}
	
	/**
	 * 回滚事务
	 * @return
	 */
	public void rollback()throws SQLException{
		Object obj = tansaction_connecton_map.get();
		if(obj != null){
			Connection con = (Connection)obj;
			con.rollback();
		}
	}
	
	/**
	 * 销毁事务,释放连接
	 * @return
	 * @throws SQLException
	 * @throws NamingException
	 */
	public void distroyTransaction(){
		try {
			Object obj = tansaction_connecton_map.get();
			if(obj != null){
				Connection con = (Connection)obj;
				con.setAutoCommit(true);
				con.close();
			}
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}finally{
			tansaction_connecton_map.set(null);
		}
	}
	
	private Connection createConnection() throws SQLException, NamingException{
		Connection con = this.dataSourceFactory.createDataSource().getConnection();
        return con;
	}

	//获得数据库连接
     public  Connection getConnection() throws SQLException, NamingException{
    	 Connection con = null;
    	 Object obj = tansaction_connecton_map.get(); 
		 if(obj != null){
			con = (Connection)obj;
		 }else{
			 con = createConnection(); 
		 }
         return con;
     }
     
     
     //增加记录
     public int insert(String sql, Object[] params)throws SQLException, Exception{
    	 int key = -1;
    	 Connection con = null;
    	 PreparedStatement stmt = null;
    	 ResultSet rs = null;
    	 try {
			 con = getConnection();
			 stmt = con.prepareStatement(sql);
			 setParams(stmt, params);
			 int nums = stmt.executeUpdate(); //返回受影响的行数
			 if(nums != 0){
				 rs = stmt.getGeneratedKeys();
				 if(rs != null && rs.next()){
					 key = rs.getInt(1);
				 }
			 }
		}finally{
			closeResultSet(rs);
			closeStatement(stmt);
			closeConnection(con);
		}
		return key;
     }
     
     //删除记录
     public int delete(String sql, Object[] params)throws SQLException, Exception{
    	 int nums = -1;
    	 Connection con = null;
    	 PreparedStatement stmt = null;
    	 try {
			 con = getConnection();
			 stmt = con.prepareStatement(sql);
			 setParams(stmt, params);
			 nums = stmt.executeUpdate(); //返回受影响的行数
		}finally{
			closeStatement(stmt);
			closeConnection(con);
		}
    	 return nums;
     }
     
     //修改
     public int update(String sql, Object[] params)throws SQLException,NamingException{
    	 int num = -1;
    	 Connection con = null;
    	 PreparedStatement stmt = null;
    	 try {
			 con = getConnection();
			 stmt = con.prepareStatement(sql);
			 setParams(stmt, params);
			 num = stmt.executeUpdate(); //返回受影响的行数	
		 }finally{
			 closeStatement(stmt);
			 closeConnection(con);
		 }
    	 return num;
     } 
     
   //求和统计
 	public long count(String sql, Object[] params){
 		long count = 0;
 		Connection con = null;
 		PreparedStatement stmt = null;
 		ResultSet rs = null;
 		try {
 			con = this.getConnection();
 			stmt = con.prepareStatement(sql);
 			setParams(stmt, params);
 			rs = stmt.executeQuery();
 			while(rs != null && rs.next()){
 				count = rs.getLong(1);
 			}
 		}catch (Exception e) {
 			// TODO Auto-generated catch block
 			e.printStackTrace();
 		}finally{
 			closeResultSet(rs);
 			closeStatement(stmt);
 			closeConnection(con);
 		}
 		return count;
 	}     
     
     //关闭ResultSet 
     public void closeResultSet(ResultSet rs){
    	 try {
			if(rs != null){
				 rs.close();
			 }
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
     }
     
     //关闭Statement
     public void closeStatement(Statement stmt){
    	 try {
			if(stmt != null){
				 stmt.close();
			 }
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
     }
     
    //归还连接到连接池
     public void closeConnection(Connection con){
    	 try {
        	 Object obj = tansaction_connecton_map.get(); 
    		 if(obj == null && con != null){ //不在事务控制中时直接关闭
    			 con.close();
    		 }
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
     }
     
     //设置参数
     public void setParams(PreparedStatement stmt, Object[] params) throws SQLException{
    	 for(int i = 0; i<params.length;){//int, short, long, float, double,  boolean, byte, char, List
    		 stmt.setObject(i+1, params[i]);
			 i++;
		 }
     }
     
     public List<Object> findFields(String sql,final Object[] params)throws SQLException,NamingException{
 		List<Object> list = null;
 		Connection con = null;
 		PreparedStatement stmt = null;
 		ResultSet rs = null;
 		try {
			con = this.getConnection();
			stmt = con.prepareStatement(sql);
			setParams(stmt, params);
			rs = stmt.executeQuery();
			ResultSetMetaData md = rs.getMetaData();
			int columnCount = md.getColumnCount();
			if(rs != null && rs.next()){
				list = new ArrayList<Object>();
				do{
					Object[] item = new Object[columnCount];
					list.add(item);
					for (int i = 0; i < columnCount; i++) { 
					   item[i] = rs.getObject(i+1);
					}
			    }while(rs.next());
			}
		}finally{
 			closeResultSet(rs);
 			closeStatement(stmt);
 			closeConnection(con);
 		}
 		return list;
     }
     
     /**
      * 专用于前台统计
      * @param dataSourceFactory
      */
     public MyResultSet countNumService(String sql,final Object[] params){
    	MyResultSet myResult = null;
  		Connection con = null;
  		PreparedStatement stmt = null;
  		ResultSet rs = null;
  		try {
  			con = this.getConnection();
  			stmt = con.prepareStatement(sql);
  			setParams(stmt, params);
  			rs = stmt.executeQuery();
  			myResult = new MyResultSet();
  			myResult.setCon(con);
  			myResult.setStmt(stmt);
  			myResult.setRs(rs);
  		}catch (Exception e) {
  			// TODO Auto-generated catch block
  			e.printStackTrace();
  		}
  		return myResult;
     }
     
     public void setDataSourceFactory(DataSourceFactory dataSourceFactory) {
 		this.dataSourceFactory = dataSourceFactory;
 	}
}
其中 DataSourceFactory  是个获得数据源的工厂,大家可以使用任何方式获得。


 这样,在把最初的addComment()方法改造下,就成了:

public boolean addComment(Comment comment){
     try {    
          this.blogDao.startTransaction();//开启事务
          //第一步,保存评论
          save(comment);
          //第二步,blog的评论数total_comment字段加1
          Blog blog = blogService.getBlogById(comment.getBlogId);
          blog.setTotal_comment(blog.getTotal_comment() + 1);
          update_total(blog );
     } catch (Exception e) {
	 // TODO Auto-generated catch block
	 e.printStackTrace();
	 try {
	     this.blogDao.rollback();//回滚事务
	 } catch (SQLException e1) {
	     // TODO Auto-generated catch block
	     e1.printStackTrace();
         }
    }finally{
	 this.blogDao.distroyTransaction();//销毁事务
    }
}

其中  blogDao继承了GenericDao。  

完毕。