Notes on several solutions for writing data back to Oracle

Because we were on Spark 1.5.1, we hit quite a few unexpected bugs. I am recording them here for reference.

First, the requirement: Hive tables have to be written back into Oracle, and it has to go through Spark SQL, so Sqoop is not an option — the cluster's big-data platform has no Sqoop component. The output must be type-exact: whatever column types the data had when it was originally extracted from Oracle, it must come back to Oracle with the same types and the same precision.
The difficulty is that on the big-data platform Hive stores DATE columns as STRING, and Hive strings carry no declared length.

1. First solution:

Since we are not allowed to access Hive's metastore directly, the idea was: read the source table with sqlContext.sql to get its schema, convert the DataFrame to an RDD, read Oracle's system tables to get the final data types and lengths, rebuild the schema, and recombine it with the RDD into a new DataFrame. Then write it out with write.jdbc, adding

option("createTableColumnTypes", "name varchar(200)")

so that the table is created with the exact column types. After testing, however, this option cannot be used with Spark 1.5.1; it is only available from Spark 2.2.0 onwards.
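For reference, on Spark 2.2.0+ the intended usage is roughly the following (a sketch only, reusing the df2/url/table/connectionProperties defined in the code below):

df2.write
  .option("createTableColumnTypes", "name varchar(200)")
  .jdbc(url, table, connectionProperties)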
The full test code (run against 1.5.1) is as follows:

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

object ojdbcTest {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    // read the source table from Hive to get its data and schema
    var df = sqlContext.sql("select * from  ****.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

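    // pull the column metadata (name, type, length, precision, scale) of the target table from Oracle's user_tab_cols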
    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:***/*****@//*****/*****",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();

    var string = new ArrayList[String]

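    // walk the Oracle metadata: remember the DATE columns, record VARCHAR2 lengths, and build the column DDL fragments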
    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())
        string.add(column_name.toString() + " " + data_type.toString())
      }

      if (data_type.equals("NUMBER")) {
        if (data_precision != null) {
          string.add(column_name.toString() + " " + data_type.toString() + s"(${data_precision.toString().toDouble.intValue()},${data_scale.toString().toDouble.intValue()})")
        } else {
          string.add(column_name.toString() + " " + data_type.toString())
        }

      }
      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))
        string.add(column_name.toString() + " " + data_type.toString() + s"(${data_length.toString().toDouble.intValue()})")
      }

    }
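    // rebuild the StructFields, switching the columns that are DATE in Oracle over to DateType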
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name: " + dataName + ", column type: " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = DateType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

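    // cast the Hive string columns that were flagged as Oracle DATE columns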
    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(DateType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//*******/***"
    val table = "test2"
    val user = "***"
    val password = "***"

    val url1="jdbc:oracle:thin:***/***@//***/***"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")

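    // java.util.ArrayList.toString yields "[COL1 TYPE, COL2 TYPE, ...]"; strip the brackets to get a plain comma-separated column list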
    val a = string.toString()
    val option = a.substring(1, a.length() - 1)
    println(option)

    df2.write.option("createTableColumnTypes", s"${option}").jdbc(url, table, connectionProperties)

    sc.stop()
  }
} 

The code is written rather casually — it is only a test class.

2. Second solution:

Given the situation above — the first approach does not work on 1.5.1 — we then tried a different route: override the three methods of the JdbcDialect class that are used for reading and writing. These are the methods Spark SQL uses to map JDBC database types, so overriding them gives you simple type conversions.

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

object ojdbcTest {



  def oracleInit() {

    val dialect: JdbcDialect = new JdbcDialect() {
      override def canHandle(url: String) = {
        url.startsWith("jdbc:oracle")
      }

      // type mapping used when reading from Oracle;
      // returning None keeps Spark's default mapping
      override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None

      // type mapping used when writing to Oracle
      override def getJDBCType(dt: DataType): Option[org.apache.spark.sql.jdbc.JdbcType] =
        dt match {
          case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
          case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
          case LongType    => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
          case FloatType   => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
          case DoubleType  => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
          case ByteType    => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
          case ShortType   => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
          case StringType  => Some(JdbcType("VARCHAR2(250)", java.sql.Types.VARCHAR))
          case DateType    => Some(JdbcType("DATE", java.sql.Types.DATE))
          case DecimalType.Unlimited => Some(JdbcType("NUMBER", java.sql.Types.NUMERIC))
          case _ => None
        }
    }
    JdbcDialects.registerDialect(dialect)
  }


  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    // read the source table from Hive to get its data and schema
    var df = sqlContext.sql("select * from  ****.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:****/****@//********/claimamdb",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();



    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())

      }


      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))

      }

    }
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name: " + dataName + ", column type: " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = DateType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(DateType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//********/claimamdb"
    val table = "test2"
    val user = "****"
    val password = "****"

    val url1="jdbc:oracle:thin:****/****@//********/claimamdb"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")




    oracleInit()
    df2.write.jdbc(url, table, connectionProperties)

    sc.stop()



  }
}

This approach only handles simple one-to-one type conversions. It cannot solve my actual problem — dates that Hive already stores as strings need to go back to Oracle as DATE — because even an overridden getJDBCType cannot be passed any extra parameters, so there is no way to tell it which string column is really a date. One could extend the Logging class and rewrite jdbcUtils to get that kind of control, but that means working through the source code, which is somewhat involved.
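To make the limitation concrete: getJDBCType only ever sees the Spark DataType, so per-column decisions are impossible there. A rewritten jdbcUtils would have to build the CREATE TABLE column list from per-column information, roughly along these lines (a sketch of the idea only, not Spark's code — the helper name, the oracleTypes map and the VARCHAR2(4000) fallback are my own assumptions):

import org.apache.spark.sql.DataFrame

def schemaString(df: DataFrame, oracleTypes: Map[String, String]): String =
  df.schema.fields.map { field =>
    // use the exact Oracle type taken from user_tab_cols when we have one,
    // otherwise fall back to a maximum-length VARCHAR2 so nothing is truncated
    val colType = oracleTypes.getOrElse(field.name.toUpperCase, "VARCHAR2(4000)")
    s"${field.name} ${colType}" + (if (field.nullable) "" else " NOT NULL")
  }.mkString(", ")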

3. Third solution

The code is the same as in the first solution.
Since there is no way to make that approach create the table with exact column types — a string written to Oracle without a declared length simply defaults to 255 — I changed the write step to use createJDBCTable plus insertIntoJDBC(url1, table, true) (see the sketch after the quoted documentation below). It turned out that insertIntoJDBC is buggy in this version. The official documentation says:

Save this DataFrame to a JDBC database at url under the table name table. Assumes the table already exists and has a compatible schema. If you pass true for overwrite, it will TRUNCATE the table before performing the INSERTs. 

The table must already exist on the database. It must have a schema that is compatible with the schema of this RDD; inserting the rows of the RDD in order via the simple statement INSERT INTO table VALUES (?, ?, ..., ?) should not fail.
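
Concretely, the write step at the end of the solution-1 code was swapped for something along these lines (a sketch; both methods exist but were already deprecated in 1.5.x, and the createJDBCTable arguments are my assumption — only insertIntoJDBC(url1, table, true) is what the text above actually refers to):

df2.createJDBCTable(url1, table, false)   // issues the CREATE TABLE over JDBC
df2.insertIntoJDBC(url1, table, true)     // true is documented to TRUNCATE and then INSERT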

Yet it still failed with an error saying the table already exists. Searching English-language sites showed that this is a known bug in this version. (Screenshots of those search results are omitted here.)

So, after going through all of that and ruling out the approaches above, how do we get our data in with exact types and precision?

4. Fourth solution

Oracle's maximum VARCHAR2 length is 4000, so the plan is this: override the dialect's getJDBCType so that every string column is written as VARCHAR2(4000), which guarantees nothing gets truncated; take the exact CREATE TABLE statement for the target table (built from the system-table metadata) and run it through plain Oracle JDBC; write the DataFrame into a temporary Oracle table whose VARCHAR2 columns are all 4000; and finally run an INSERT ... SELECT to move the data from the temporary table into the target table.

For the DATE columns, once they are identified from the system table, the corresponding DataFrame columns are cast to timestamp, and the overridden getJDBCType maps TimestampType down to Oracle's DATE type, so the dates are not truncated either.

The code is as follows:

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

import java.sql.Connection
import java.sql.DriverManager
object ojdbcTest {



      def oracleInit(){

        val dialect:JdbcDialect= new JdbcDialect() {
          override def canHandle(url:String)={
            url.startsWith("jdbc:oracle");
          }

//       override def getCatalystType(sqlType, typeName, size, md):Option[DataType]={
    //
    //
    //      }
          override def getJDBCType(dt:DataType):Option[org.apache.spark.sql.jdbc.JdbcType]=

            dt match{
              case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
              case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
              case LongType    => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
              case FloatType   => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
              case DoubleType  => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
              case ByteType    => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
              case ShortType   => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
              case StringType  => Some(JdbcType("VARCHAR2(4000)", java.sql.Types.VARCHAR))
              case DateType    => Some(JdbcType("DATE", java.sql.Types.DATE))
              case DecimalType.Unlimited => Some(JdbcType("NUMBER",java.sql.Types.NUMERIC))
              case TimestampType=> Some(JdbcType("DATE",java.sql.Types.DATE))
              case _ => None
            }

        }
         JdbcDialects.registerDialect(dialect);
      }

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    // read the source table from Hive to get its data and schema
    var df = sqlContext.sql("select * from  ******.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

    //val customSchema = sparkTargetDF.dtypes.map(x => x._1+" "+x._2).mkString(",").toUpperCase()
    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:********/********@//********/********",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();

    var string = new ArrayList[String]

    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())
        string.add(column_name.toString() + " " + data_type.toString())
      }

      if (data_type.equals("NUMBER")) {
        if (data_precision != null) {
          string.add(column_name.toString() + " " + data_type.toString() + s"(${data_precision.toString().toDouble.intValue()},${data_scale.toString().toDouble.intValue()})")
        } else {
          string.add(column_name.toString() + " " + data_type.toString())
        }

      }
      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))
        string.add(column_name.toString() + " " + data_type.toString() + s"(${data_length.toString().toDouble.intValue()})")
      }

    }
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name: " + dataName + ", column type: " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = TimestampType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(TimestampType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//********/********"
    val table = "test2"
    val table1="test3"
    val user = "********"
    val password = "********"

    val url1 = "jdbc:oracle:thin:********/********@//********/********"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")

    val a = string.toString()
    val option = a.substring(1, a.length() - 1)
    println(option)

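    // register the Oracle dialect, create the real target table from the exact DDL string,
    // write the DataFrame into the all-VARCHAR2(4000) staging table, then move the rows
    // across with a plain INSERT ... SELECT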
    oracleInit()

    createJdbcTable(option,table)

    println("create table is finish!")

    df2.write.jdbc(url, table1, connectionProperties)

    insertTable(table,table1)
    println("Data has been loaded into the target table!")


    sc.stop()
    //option("createTableColumnTypes", "CLAIMNO VARCHAR2(300), comments VARCHAR(1024)")
    //df2.select(df2("POLICYNO")).write.option("createTableColumnTypes", "CLAIMNO VARCHAR2(200)")
    //.jdbc(url, table, connectionProperties)
  }

  def createJdbcTable(option: String, table: String) = {

    val url = "jdbc:oracle:thin:@//********/********"
    // driver class name
    val driver = "oracle.jdbc.driver.OracleDriver"
    // user name
    val username = "********"
    // password
    val password = "********"
    // connection handle
    var connection: Connection = null
    try {
      // register the driver
      Class.forName(driver)
      // open the connection
      connection = DriverManager.getConnection(url, username, password)
      val statement = connection.createStatement
      // build and run the CREATE TABLE statement (DDL, so executeUpdate rather than executeQuery)
      val sql = s"""
        create table ${table}
(
 ${option}
)
        """
      statement.executeUpdate(sql)
    } catch { case e: Exception => e.printStackTrace }
    finally {
      // close the connection and release resources
      if (connection != null) connection.close()
    }
  }

  def insertTable(table: String, table1: String) {
    val url = "jdbc:oracle:thin:@//********/********"
    // driver class name
    val driver = "oracle.jdbc.driver.OracleDriver"
    // user name
    val username = "********"
    // password
    val password = "*********"
    // connection handle
    var connection: Connection = null
    try {
      // register the driver
      Class.forName(driver)
      // open the connection
      connection = DriverManager.getConnection(url, username, password)
      val statement = connection.createStatement
      // move the rows from the staging table into the target table (DML, so executeUpdate)
      val sql = s"""
        insert into ${table} select * from  ${table1}
        """
      statement.executeUpdate(sql)
    } catch { case e: Exception => e.printStackTrace }
    finally {
      // close the connection and release resources
      if (connection != null) connection.close()
    }
  }
}

There are plenty of other version-specific pitfalls. For example, with
write.mode().jdbc()
whatever SaveMode you pass to mode() — Append, Ignore, anything — the table still ends up overwritten. Looking at the source, the save mode is effectively hard-wired to overwrite in this version.
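The call in question looks like this (a sketch; SaveMode is already imported in the code above):

df2.write.mode(SaveMode.Append).jdbc(url, table, connectionProperties)   // in our 1.5.1 tests this still behaved like Overwrite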
For a detailed discussion of this issue, see:

https://www.2cto.com/net/201609/551130.html

I hope this saves you some detours!


Reposted from blog.csdn.net/jxlxxxmz/article/details/80083200